the-puzzler committed on
Commit
b9b97d8
·
1 Parent(s): 515a8b4
Files changed (2) hide show
  1. app.py +210 -176
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,14 +1,16 @@
1
  # app.py
2
- import os, re, math, random
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  import gradio as gr
7
  from transformers import AutoTokenizer
 
 
8
 
9
- # -----------------------------
10
  # Minimal CNA (inference-ready)
11
- # -----------------------------
12
  class AttnBlock(nn.Module):
13
  def __init__(self, embed_dim, num_heads, expansion_factor):
14
  super().__init__()
@@ -28,7 +30,7 @@ class AttnBlock(nn.Module):
28
  nn.Linear(embed_dim * expansion_factor, embed_dim),
29
  )
30
 
31
- # match training's zero-init on residual branches
32
  nn.init.zeros_(self.Wo.weight); nn.init.zeros_(self.Wo.bias)
33
  nn.init.zeros_(self.mlp[-1].weight); nn.init.zeros_(self.mlp[-1].bias)
34
 
@@ -118,35 +120,9 @@ class CNA(nn.Module):
118
  h = blk(h, rope=(cos, sin), radius=self.radius)
119
  return self.proj(h)
120
 
121
- # -----------------------------
122
  # Helpers
123
- # -----------------------------
124
- @torch.no_grad()
125
- def sample_from_logits(logits_row: torch.Tensor, temperature: float = 1.0,
126
- current_token: int | None = None, exclude_current: bool = True) -> int:
127
- """
128
- Sample a token from logits_row using softmax with temperature.
129
- If exclude_current=True and current_token is provided, set its prob to 0 (then renormalize).
130
- """
131
- if temperature <= 0:
132
- # safety: treat as argmax
133
- return int(torch.argmax(logits_row).item())
134
-
135
- scaled = logits_row / float(temperature)
136
- probs = torch.softmax(scaled, dim=-1)
137
-
138
- if exclude_current and current_token is not None:
139
- probs = probs.clone()
140
- probs[current_token] = 0.0
141
- s = probs.sum()
142
- if s.item() <= 0:
143
- # fallback to argmax if everything got zeroed
144
- return int(torch.argmax(logits_row).item())
145
- probs = probs / s
146
-
147
- return int(torch.multinomial(probs, num_samples=1).item())
148
-
149
-
150
  def infer_expansion_factor_from_state(state, embed_dim):
151
  for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
152
  if key in state:
@@ -191,7 +167,7 @@ def apply_noise_ops(x, tokenizer, indices_csv, add_noise_left, add_noise_right,
191
 
192
  # noise brush (indices like "0, 5, 6-10")
193
  idxs = set()
194
- if indices_csv.strip():
195
  for part in indices_csv.split(","):
196
  part = part.strip()
197
  if not part:
@@ -210,15 +186,15 @@ def apply_noise_ops(x, tokenizer, indices_csv, add_noise_left, add_noise_right,
210
  except:
211
  continue
212
  for j in idxs:
213
- if 0 <= j < seqlen:
214
  x[0, j] = rnd.randrange(V)
215
 
216
  # prepend/append random noise
217
  if add_noise_left > 0:
218
- prefix = torch.tensor([rnd.randrange(V) for _ in range(add_noise_left)], dtype=torch.long).unsqueeze(0)
219
  x = torch.cat([prefix, x], dim=1)
220
  if add_noise_right > 0:
221
- suffix = torch.tensor([rnd.randrange(V) for _ in range(add_noise_right)], dtype=torch.long).unsqueeze(0)
222
  x = torch.cat([x, suffix], dim=1)
223
 
224
  # force length back to seqlen (trim or pad random)
@@ -230,53 +206,164 @@ def apply_noise_ops(x, tokenizer, indices_csv, add_noise_left, add_noise_right,
230
  x = torch.cat([x, pad], dim=1)
231
  return x
232
 
233
- # -----------------------------
234
- # Load checkpoint & build model
235
- # -----------------------------
236
- DEFAULT_CKPT = os.environ.get("CKPT_PATH", "ckpt_latest.pt")
237
- model_cache = {"model": None, "tokenizer": None, "radius": None, "ckpt": None}
 
 
 
 
 
 
 
 
 
 
238
 
239
- def load_model(ckpt_path: str):
240
- if not os.path.exists(ckpt_path):
241
- raise FileNotFoundError(
242
- f"Checkpoint not found at {ckpt_path}. "
243
- "Upload ckpt_latest.pt to the repo root or set the correct path."
244
- )
245
- payload = torch.load(ckpt_path, map_location="cpu")
246
- state = payload["model"]
247
- cfg = payload.get("config", {}) or {}
248
 
249
- # Carry over config (robust fallbacks)
250
  embed_dim = cfg.get("embed_dim")
251
  num_heads = cfg.get("num_heads")
252
  num_blocks = cfg.get("num_blocks")
253
  radius = cfg.get("radius")
254
  expansion_factor = cfg.get("expansion_factor")
 
255
 
256
- if embed_dim is None: embed_dim = state["tok_emb.weight"].shape[1]
 
257
  if num_blocks is None:
258
  block_idxs = [int(m.group(1)) for k in state.keys() for m in [re.match(r"blocks\.(\d+)\.", k)] if m]
259
  num_blocks = max(block_idxs) + 1 if block_idxs else 1
260
- if num_heads is None: num_heads = 8
261
- if radius is None: radius = 16
 
 
262
  if expansion_factor is None:
263
  expansion_factor = infer_expansion_factor_from_state(state, embed_dim)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  else:
265
- expansion_factor = int(expansion_factor)
 
 
 
 
 
 
 
 
 
266
 
267
- tokenizer_name = payload.get("tokenizer_name", "gpt2")
268
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  if tokenizer.pad_token is None:
270
  tokenizer.pad_token = tokenizer.eos_token
271
  tokenizer.model_max_length = 1_000_000_000
272
  vocab_size = tokenizer.vocab_size
273
 
 
274
  model = CNA(
275
- int(embed_dim), int(num_heads), int(expansion_factor),
276
- int(num_blocks), int(radius), int(vocab_size)
277
  )
278
 
279
- # Load weights (tolerate proj head size diff)
280
  missing, unexpected = model.load_state_dict(state, strict=False)
281
  if any(k.startswith("proj.") for k in missing):
282
  with torch.no_grad():
@@ -286,155 +373,132 @@ def load_model(ckpt_path: str):
286
  model.load_state_dict(state, strict=True)
287
 
288
  model.eval()
289
- return model, tokenizer, int(radius)
290
 
291
- def ensure_model(ckpt_path):
292
- if model_cache["model"] is None or model_cache["ckpt"] != ckpt_path:
293
- m, tok, rad = load_model(ckpt_path)
294
- model_cache.update({"model": m, "tokenizer": tok, "radius": rad, "ckpt": ckpt_path})
295
 
296
- # -----------------------------
297
- # Strategy 1 core step
298
- # -----------------------------
 
 
 
 
 
 
299
  @torch.no_grad()
300
- def step_strategy1(model, x, mode: str = "argmax",
301
- temperature: float = 1.0,
302
- exclude_current: bool = True):
303
- """
304
- One iteration: choose random position, then update via:
305
- - mode="argmax": set token to argmax(logits)
306
- - mode="sample": sample from softmax(logits / temperature)
307
- (optionally excluding current token)
308
- """
309
  S = x.shape[1]
310
  pos = int(torch.randint(0, S, (1,)).item())
311
  logits_pos = model_logits(model, x)[0, pos] # [V]
312
-
313
  if mode == "sample":
314
  cur_tok = int(x[0, pos].item())
315
  new_tok = sample_from_logits(
316
  logits_pos,
317
  temperature=float(temperature),
318
  current_token=cur_tok,
319
- exclude_current=bool(exclude_current)
320
  )
321
  x[0, pos] = new_tok
322
  else:
323
- # default / fallback: argmax
324
  x[0, pos] = int(torch.argmax(logits_pos).item())
325
  return x
326
 
327
-
328
- # -----------------------------
329
- # Gradio logic
330
- # -----------------------------
331
- def init_random(ckpt_path, seqlen, seed):
332
- ensure_model(ckpt_path or DEFAULT_CKPT)
333
  random.seed(seed); torch.manual_seed(seed)
334
  V = model_cache["tokenizer"].vocab_size
335
- x = torch.randint(0, V, (1, seqlen))
336
  txt = decode(x[0], model_cache["tokenizer"])
337
- return x.tolist(), txt, f"Initialized random sequence (len={seqlen})"
338
 
339
- def init_from_text(ckpt_path, seqlen, text, seed, pad_mode):
340
- ensure_model(ckpt_path or DEFAULT_CKPT)
341
  rnd = random.Random(seed)
342
- x = to_fixed_len_ids(text or "", model_cache["tokenizer"], seqlen, pad_mode=pad_mode, rnd=rnd)
343
  txt = decode(x[0], model_cache["tokenizer"])
344
  return x.tolist(), txt, "Initialized from text"
345
 
346
- def append_text(ckpt_path, state_ids, seqlen, text_to_append, seed):
347
- ensure_model(ckpt_path or DEFAULT_CKPT)
348
  tok = model_cache["tokenizer"]
349
  rnd = random.Random(seed)
 
350
  if state_ids is None or len(state_ids) == 0:
351
- x = to_fixed_len_ids(text_to_append or "", tok, seqlen, pad_mode="random", rnd=rnd)
352
  else:
353
  x = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
354
- # append
355
  extra = tok.encode(text_to_append or "", add_special_tokens=False)
356
  x = torch.cat([x, torch.tensor(extra, dtype=torch.long).unsqueeze(0)], dim=1)
357
- # force length
358
- if x.shape[1] > seqlen:
359
- x = x[:, :seqlen]
360
- elif x.shape[1] < seqlen:
361
- need = seqlen - x.shape[1]
362
  V = tok.vocab_size
363
  pad = torch.tensor([rnd.randrange(V) for _ in range(need)], dtype=torch.long).unsqueeze(0)
364
  x = torch.cat([x, pad], dim=1)
365
  txt = decode(x[0], tok)
366
  return x.tolist(), txt, "Appended text and resized to target length"
367
 
368
- def apply_noise(ckpt_path, state_ids, seqlen, indices_csv, add_left, add_right, seed):
369
- ensure_model(ckpt_path or DEFAULT_CKPT)
370
  tok = model_cache["tokenizer"]
 
371
  if state_ids is None or len(state_ids) == 0:
372
- # create an empty base (random) then apply ops
373
  V = tok.vocab_size
374
- base = torch.randint(0, V, (1, seqlen))
375
  else:
376
  base = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
377
- x = apply_noise_ops(base, tok, indices_csv, int(add_left), int(add_right), seqlen, seed=seed)
378
  txt = decode(x[0], tok)
379
  return x.tolist(), txt, "Applied noise brush / prepend / append"
380
 
381
- def step_once(ckpt_path, state_ids, mode, temperature, exclude_current):
382
- ensure_model(ckpt_path or DEFAULT_CKPT)
383
  tok = model_cache["tokenizer"]
384
  if state_ids is None or len(state_ids) == 0:
385
  return None, "", "No sequence to step — initialize first."
386
  x = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
387
- x = step_strategy1(
388
- model_cache["model"], x,
389
- mode=mode,
390
- temperature=temperature,
391
- exclude_current=exclude_current
392
- )
393
  txt = decode(x[0], tok)
394
  return x.tolist(), txt, f"Stepped 1 iteration ({mode})"
395
 
396
- def live_denoise(ckpt_path, state_ids, steps, snap_every, seed,
397
- mode, temperature, exclude_current):
398
- """
399
- Generator for live updates. Yields (ids, text, status) every snap_every steps and on completion.
400
- """
401
- ensure_model(ckpt_path or DEFAULT_CKPT)
402
  tok = model_cache["tokenizer"]
403
  if state_ids is None or len(state_ids) == 0:
404
  return
405
  random.seed(seed); torch.manual_seed(seed)
406
  x = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
407
- total = int(steps)
408
- snap = max(1, int(snap_every))
409
  for t in range(1, total + 1):
410
- x = step_strategy1(
411
- model_cache["model"], x,
412
- mode=mode,
413
- temperature=temperature,
414
- exclude_current=exclude_current
415
- )
416
  if (t % snap == 0) or (t == total):
417
  txt = decode(x[0], tok)
418
  yield x.tolist(), txt, f"Live denoise… step {t}/{total} ({mode})"
419
 
420
-
421
- # -----------------------------
422
  # UI
423
- # -----------------------------
424
- with gr.Blocks(title="CNA — Interactive Denoising (Strategy 1)") as demo:
425
  gr.Markdown(
426
  """
427
  # CNA — Interactive Denoising (Strategy 1)
428
- - **Mode 1:** Randomize then watch it **denoise live** (random-position argmax).
429
- - **Mode 2:** Initialize from **your text**.
430
- - **Noise Brush:** Select positions (e.g., `0, 5, 10-20`), and/or add random noise tokens at **start**/**end**.
431
- - **Append:** Add your text to the current sequence.
432
  """
433
  )
434
 
435
  # Global settings
 
436
  with gr.Row():
437
- ckpt = gr.Textbox(value=DEFAULT_CKPT, label="Checkpoint path")
438
  seqlen = gr.Slider(10, 512, value=100, step=1, label="Sequence length (S)")
439
  seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
440
 
@@ -447,24 +511,14 @@ with gr.Blocks(title="CNA — Interactive Denoising (Strategy 1)") as demo:
447
  status = gr.Markdown("Ready.")
448
 
449
  gr.Markdown("## Mode 1 · Random → Denoise Live")
450
- with gr.Row():
451
- update_mode = gr.Radio(
452
- choices=["argmax", "sample"],
453
- value="argmax",
454
- label="Update rule"
455
- )
456
- temperature = gr.Slider(
457
- minimum=0.0, maximum=5.0, value=1.0, step=0.05,
458
- label="Temperature (sampling)"
459
- )
460
- exclude_current = gr.Checkbox(
461
- value=True,
462
- label="Exclude current token when sampling"
463
- )
464
  with gr.Row():
465
  btn_random = gr.Button("Initialize Random")
466
  steps = gr.Slider(1, 2000, value=200, step=1, label="Denoise steps (N)")
467
  snap_every = gr.Slider(1, 100, value=5, step=1, label="Update every K steps")
 
 
 
 
468
  with gr.Row():
469
  btn_step_once = gr.Button("Step Once")
470
  btn_live = gr.Button("Denoise Live (streaming)")
@@ -490,48 +544,28 @@ with gr.Blocks(title="CNA — Interactive Denoising (Strategy 1)") as demo:
490
  btn_append = gr.Button("Append to Current Sequence")
491
 
492
  # --- Wiring ---
493
- # Random init
494
- out = btn_random.click(
495
- init_random,
496
- [ckpt, seqlen, seed],
497
- [ids_state, current_text, status]
498
- )
499
 
500
- # Init from text
501
- btn_init_text.click(
502
- init_from_text,
503
- [ckpt, seqlen, init_text, seed, pad_mode],
504
- [ids_state, current_text, status]
505
- )
506
 
507
- # Apply noise
508
  btn_apply_noise.click(
509
- apply_noise,
510
- [ckpt, ids_state, seqlen, indices_csv, add_left, add_right, seed],
511
  [ids_state, current_text, status]
512
  )
513
 
514
- # Append text
515
- btn_append.click(
516
- append_text,
517
- [ckpt, ids_state, seqlen, append_box, seed],
518
- [ids_state, current_text, status]
519
- )
520
 
521
- # Single step
522
  btn_step_once.click(
523
  step_once,
524
- [ckpt, ids_state, update_mode, temperature, exclude_current],
525
  [ids_state, current_text, status]
526
  )
527
 
528
- # Live denoise (streaming)
529
  btn_live.click(
530
  live_denoise,
531
- [ckpt, ids_state, steps, snap_every, seed, update_mode, temperature, exclude_current],
532
  [ids_state, current_text, status],
533
  show_progress=True
534
  )
535
 
536
  demo.queue().launch()
537
-
 
1
  # app.py
2
+ import os, re, math, random, json
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  import gradio as gr
7
  from transformers import AutoTokenizer
8
+ from safetensors.torch import load_file as load_sft
9
+ from huggingface_hub import snapshot_download
10
 
11
+ # ============================================================
12
  # Minimal CNA (inference-ready)
13
+ # ============================================================
14
  class AttnBlock(nn.Module):
15
  def __init__(self, embed_dim, num_heads, expansion_factor):
16
  super().__init__()
 
30
  nn.Linear(embed_dim * expansion_factor, embed_dim),
31
  )
32
 
33
+ # zero-init on residual branches (to match training behavior)
34
  nn.init.zeros_(self.Wo.weight); nn.init.zeros_(self.Wo.bias)
35
  nn.init.zeros_(self.mlp[-1].weight); nn.init.zeros_(self.mlp[-1].bias)
36
 
 
120
  h = blk(h, rope=(cos, sin), radius=self.radius)
121
  return self.proj(h)
122
 
123
+ # ============================================================
124
  # Helpers
125
+ # ============================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def infer_expansion_factor_from_state(state, embed_dim):
127
  for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
128
  if key in state:
 
167
 
168
  # noise brush (indices like "0, 5, 6-10")
169
  idxs = set()
170
+ if indices_csv and indices_csv.strip():
171
  for part in indices_csv.split(","):
172
  part = part.strip()
173
  if not part:
 
186
  except:
187
  continue
188
  for j in idxs:
189
+ if 0 <= j < x.shape[1]:
190
  x[0, j] = rnd.randrange(V)
191
 
192
  # prepend/append random noise
193
  if add_noise_left > 0:
194
+ prefix = torch.tensor([rnd.randrange(V) for _ in range(int(add_noise_left))], dtype=torch.long).unsqueeze(0)
195
  x = torch.cat([prefix, x], dim=1)
196
  if add_noise_right > 0:
197
+ suffix = torch.tensor([rnd.randrange(V) for _ in range(int(add_noise_right))], dtype=torch.long).unsqueeze(0)
198
  x = torch.cat([x, suffix], dim=1)
199
 
200
  # force length back to seqlen (trim or pad random)
 
206
  x = torch.cat([x, pad], dim=1)
207
  return x
208
 
209
+ @torch.no_grad()
210
+ def sample_from_logits(logits_row, temperature=1.0, current_token=None, exclude_current=True):
211
+ """Temperature sampling; optionally exclude current token to force change."""
212
+ if temperature <= 0:
213
+ return int(torch.argmax(logits_row).item())
214
+ scaled = logits_row / float(temperature)
215
+ probs = torch.softmax(scaled, dim=-1)
216
+ if exclude_current and current_token is not None:
217
+ probs = probs.clone()
218
+ probs[current_token] = 0.0
219
+ s = probs.sum()
220
+ if s.item() <= 0:
221
+ return int(torch.argmax(logits_row).item())
222
+ probs = probs / s
223
+ return int(torch.multinomial(probs, 1).item())
224
 
225
# ============================================================
# Weight loading: file, folder, or Hub repo
# ============================================================
# Default single-file checkpoint path; override with the CKPT_PATH env var.
DEFAULT_CKPT = os.environ.get("CKPT_PATH", "ckpt_latest.pt")
# Default shard-folder path; override with the WEIGHTS_DIR env var.
DEFAULT_WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", "weights_latest")
 
 
 
 
230
 
231
+ def _read_config_from_dict_or_infer(state, cfg):
232
  embed_dim = cfg.get("embed_dim")
233
  num_heads = cfg.get("num_heads")
234
  num_blocks = cfg.get("num_blocks")
235
  radius = cfg.get("radius")
236
  expansion_factor = cfg.get("expansion_factor")
237
+ tokenizer_name = cfg.get("tokenizer_name", cfg.get("tokenizer") or "gpt2")
238
 
239
+ if embed_dim is None:
240
+ embed_dim = state["tok_emb.weight"].shape[1]
241
  if num_blocks is None:
242
  block_idxs = [int(m.group(1)) for k in state.keys() for m in [re.match(r"blocks\.(\d+)\.", k)] if m]
243
  num_blocks = max(block_idxs) + 1 if block_idxs else 1
244
+ if num_heads is None:
245
+ num_heads = 8
246
+ if radius is None:
247
+ radius = 16
248
  if expansion_factor is None:
249
  expansion_factor = infer_expansion_factor_from_state(state, embed_dim)
250
+
251
+ return {
252
+ "embed_dim": int(embed_dim),
253
+ "num_heads": int(num_heads),
254
+ "num_blocks": int(num_blocks),
255
+ "radius": int(radius),
256
+ "expansion_factor": int(expansion_factor),
257
+ "tokenizer_name": tokenizer_name,
258
+ }
259
+
260
+ def _load_state_from_pt(payload_path: str):
261
+ payload = torch.load(payload_path, map_location="cpu")
262
+ state = payload["model"]
263
+ cfg = payload.get("config", {}) or {}
264
+ if "tokenizer_name" in payload:
265
+ cfg = {**cfg, "tokenizer_name": payload["tokenizer_name"]}
266
+ return state, cfg
267
+
268
+ def _merge_state_dicts(dicts):
269
+ merged = {}
270
+ for d in dicts:
271
+ for k, v in d.items():
272
+ merged[k] = v
273
+ return merged
274
+
275
def _load_state_from_folder(weights_dir: str):
    """Load weights from a directory of .safetensors / .pt / .bin shards.

    Resolution order: a canonical "model.safetensors", then any set of
    .safetensors shards, then .pt/.bin payloads. Reads config.json when
    present; .pt payloads may also carry "config" / "tokenizer_name"
    entries, which are merged in. Returns (state_dict, config dict).
    """
    if not os.path.isdir(weights_dir):
        raise FileNotFoundError(f"Folder not found: {weights_dir}")

    cfg = {}
    cfg_path = os.path.join(weights_dir, "config.json")
    if os.path.exists(cfg_path):
        with open(cfg_path, "r") as fh:
            cfg = json.load(fh)

    entries = sorted(os.listdir(weights_dir))
    sft_files = [name for name in entries if name.endswith(".safetensors")]
    pt_files = [name for name in entries if name.endswith((".pt", ".bin"))]

    if "model.safetensors" in sft_files:
        return load_sft(os.path.join(weights_dir, "model.safetensors")), cfg

    if sft_files:
        shards = [load_sft(os.path.join(weights_dir, name)) for name in sft_files]
        return _merge_state_dicts(shards), cfg

    if pt_files:
        shards = []
        for name in pt_files:
            payload = torch.load(os.path.join(weights_dir, name), map_location="cpu")
            if isinstance(payload, dict) and isinstance(payload.get("model"), dict):
                # training-style payload: unwrap and merge its metadata
                shards.append(payload["model"])
                if isinstance(payload.get("config"), dict):
                    cfg = {**cfg, **payload["config"]}
                if "tokenizer_name" in payload:
                    cfg.setdefault("tokenizer_name", payload["tokenizer_name"])
            else:
                # bare state-dict file
                shards.append(payload)
        return _merge_state_dicts(shards), cfg

    raise FileNotFoundError(
        f"No weights found in {weights_dir}. Expected .safetensors or .pt files."
    )
314
+
315
def _load_state_from_hub(repo_id: str, subfolder: str | None = None, revision: str | None = None):
    """Snapshot a HF Hub repo locally, then load its weights as a folder."""
    local_root = snapshot_download(repo_id=repo_id, revision=revision, allow_patterns=None)
    target = os.path.join(local_root, subfolder) if subfolder else local_root
    return _load_state_from_folder(target)
319
 
320
+ def load_model(source: str):
321
+ """
322
+ `source` can be:
323
+ - Path to single-file checkpoint: 'ckpt_latest.pt'
324
+ - Path to folder of shards: 'weights_latest'
325
+ - HF Hub repo id: 'org/model'
326
+ """
327
+ # Resolve source
328
+ src = source or ""
329
+ state, cfg = None, {}
330
+
331
+ if os.path.isfile(src) and src.endswith(".pt"):
332
+ state, cfg = _load_state_from_pt(src)
333
+ elif os.path.isdir(src):
334
+ state, cfg = _load_state_from_folder(src)
335
+ elif "/" in src: # probably a hub repo id
336
+ subfolder = os.environ.get("WEIGHTS_SUBFOLDER") or None
337
+ revision = os.environ.get("WEIGHTS_REVISION") or None
338
+ state, cfg = _load_state_from_hub(src, subfolder=subfolder, revision=revision)
339
+ else:
340
+ # fallbacks
341
+ if os.path.isfile(DEFAULT_CKPT):
342
+ state, cfg = _load_state_from_pt(DEFAULT_CKPT)
343
+ elif os.path.isdir(DEFAULT_WEIGHTS_DIR):
344
+ state, cfg = _load_state_from_folder(DEFAULT_WEIGHTS_DIR)
345
+ else:
346
+ raise FileNotFoundError(
347
+ f"Could not resolve weights from '{src}'. Tried file (.pt), folder, hub repo id, "
348
+ f"then defaults ('{DEFAULT_CKPT}', '{DEFAULT_WEIGHTS_DIR}')."
349
+ )
350
+
351
+ conf = _read_config_from_dict_or_infer(state, cfg)
352
+
353
+ # Tokenizer
354
+ tokenizer = AutoTokenizer.from_pretrained(conf["tokenizer_name"], use_fast=True)
355
  if tokenizer.pad_token is None:
356
  tokenizer.pad_token = tokenizer.eos_token
357
  tokenizer.model_max_length = 1_000_000_000
358
  vocab_size = tokenizer.vocab_size
359
 
360
+ # Build model
361
  model = CNA(
362
+ conf["embed_dim"], conf["num_heads"], conf["expansion_factor"],
363
+ conf["num_blocks"], conf["radius"], vocab_size
364
  )
365
 
366
+ # Load state (tolerate projection size mismatch)
367
  missing, unexpected = model.load_state_dict(state, strict=False)
368
  if any(k.startswith("proj.") for k in missing):
369
  with torch.no_grad():
 
373
  model.load_state_dict(state, strict=True)
374
 
375
  model.eval()
376
+ return model, tokenizer, conf["radius"]
377
 
378
# Process-wide cache of the last loaded weights; "ckpt" records the source
# string so ensure_model can detect when a reload is needed.
model_cache = {"model": None, "tokenizer": None, "radius": None, "ckpt": None}
 
 
 
379
 
380
def ensure_model(source_path_or_repo):
    """Lazily (re)load model + tokenizer when the requested weights source changes.

    Resolution order for the source: explicit argument, WEIGHTS_SOURCE env
    var, then DEFAULT_WEIGHTS_DIR.
    """
    resolved = source_path_or_repo or os.environ.get("WEIGHTS_SOURCE") or DEFAULT_WEIGHTS_DIR
    cache_is_stale = model_cache["model"] is None or model_cache["ckpt"] != resolved
    if cache_is_stale:
        model, tokenizer, radius = load_model(resolved)
        model_cache.update({"model": model, "tokenizer": tokenizer, "radius": radius, "ckpt": resolved})
385
+
386
+ # ============================================================
387
+ # Strategy 1 core step (with argmax / sample toggle)
388
+ # ============================================================
389
@torch.no_grad()
def step_strategy1(model, x, mode="argmax", temperature=1.0, exclude_current=True):
    """Pick one uniformly-random position in `x` and rewrite its token.

    mode="sample" draws from softmax(logits/temperature) (optionally
    excluding the current token); any other mode does a greedy argmax
    replacement. Mutates and returns `x`.
    """
    seq_len = x.shape[1]
    pos = int(torch.randint(0, seq_len, (1,)).item())
    row_logits = model_logits(model, x)[0, pos]  # [V]

    if mode != "sample":
        # default / fallback path: greedy replacement
        x[0, pos] = int(torch.argmax(row_logits).item())
        return x

    x[0, pos] = sample_from_logits(
        row_logits,
        temperature=float(temperature),
        current_token=int(x[0, pos].item()),
        exclude_current=bool(exclude_current),
    )
    return x
407
 
408
+ # ============================================================
409
+ # Gradio callbacks
410
+ # ============================================================
411
def init_random(src, seqlen, seed):
    """Gradio callback: create a uniformly-random token sequence of length `seqlen`."""
    ensure_model(src)
    random.seed(seed); torch.manual_seed(seed)
    tok = model_cache["tokenizer"]
    length = int(seqlen)
    x = torch.randint(0, tok.vocab_size, (1, length))
    return x.tolist(), decode(x[0], tok), f"Initialized random sequence (len={length})"
418
 
419
def init_from_text(src, seqlen, text, seed, pad_mode):
    """Gradio callback: encode `text` and pad/trim it to exactly `seqlen` tokens."""
    ensure_model(src)
    tok = model_cache["tokenizer"]
    rng = random.Random(seed)
    x = to_fixed_len_ids(text or "", tok, int(seqlen), pad_mode=pad_mode, rnd=rng)
    return x.tolist(), decode(x[0], tok), "Initialized from text"
425
 
426
def append_text(src, state_ids, seqlen, text_to_append, seed):
    """Gradio callback: append encoded user text to the current sequence,
    then force it back to exactly `seqlen` tokens (truncate, or pad with
    random token ids)."""
    ensure_model(src)
    tok = model_cache["tokenizer"]
    rng = random.Random(seed)
    target_len = int(seqlen)

    if state_ids is None or len(state_ids) == 0:
        # no existing sequence: build one straight from the text
        x = to_fixed_len_ids(text_to_append or "", tok, target_len, pad_mode="random", rnd=rng)
    else:
        x = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
        extra_ids = tok.encode(text_to_append or "", add_special_tokens=False)
        x = torch.cat([x, torch.tensor(extra_ids, dtype=torch.long).unsqueeze(0)], dim=1)

    # force length back to target — presumably a no-op for the
    # to_fixed_len_ids path, which already returns target_len tokens
    if x.shape[1] > target_len:
        x = x[:, :target_len]
    elif x.shape[1] < target_len:
        need = target_len - x.shape[1]
        pad = torch.tensor([rng.randrange(tok.vocab_size) for _ in range(need)], dtype=torch.long).unsqueeze(0)
        x = torch.cat([x, pad], dim=1)

    return x.tolist(), decode(x[0], tok), "Appended text and resized to target length"
446
 
447
def apply_noise(src, state_ids, seqlen, indices_csv, add_left, add_right, seed):
    """Gradio callback: apply the noise brush (and optional prepend/append
    noise) to the stored sequence, starting from random tokens if empty."""
    ensure_model(src)
    tok = model_cache["tokenizer"]
    target_len = int(seqlen)

    if state_ids:
        base = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
    else:
        # nothing loaded yet: start from a random sequence
        base = torch.randint(0, tok.vocab_size, (1, target_len))

    noised = apply_noise_ops(base, tok, indices_csv, int(add_left or 0), int(add_right or 0), target_len, seed=seed)
    return noised.tolist(), decode(noised[0], tok), "Applied noise brush / prepend / append"
459
 
460
def step_once(src, state_ids, mode, temperature, exclude_current):
    """Gradio callback: run exactly one denoising update on the stored sequence."""
    ensure_model(src)
    tok = model_cache["tokenizer"]
    if not state_ids:
        return None, "", "No sequence to step — initialize first."
    seq = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
    seq = step_strategy1(
        model_cache["model"], seq,
        mode=mode, temperature=temperature, exclude_current=exclude_current,
    )
    return seq.tolist(), decode(seq[0], tok), f"Stepped 1 iteration ({mode})"
469
 
470
def live_denoise(src, state_ids, steps, snap_every, seed, mode, temperature, exclude_current):
    """Streaming generator: run `steps` denoising updates, yielding
    (ids, text, status) every `snap_every` steps and on the final step.
    Yields nothing when no sequence has been initialized."""
    ensure_model(src)
    tok = model_cache["tokenizer"]
    if not state_ids:
        return
    random.seed(seed)
    torch.manual_seed(seed)
    seq = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
    total_steps = int(steps)
    snapshot_interval = max(1, int(snap_every))
    for step_no in range(1, total_steps + 1):
        seq = step_strategy1(
            model_cache["model"], seq,
            mode=mode, temperature=temperature, exclude_current=exclude_current,
        )
        if step_no % snapshot_interval == 0 or step_no == total_steps:
            yield seq.tolist(), decode(seq[0], tok), f"Live denoise… step {step_no}/{total_steps} ({mode})"
484
 
485
+ # ============================================================
 
486
  # UI
487
+ # ============================================================
488
+ with gr.Blocks(title="CNA — Interactive Denoising") as demo:
489
  gr.Markdown(
490
  """
491
  # CNA — Interactive Denoising (Strategy 1)
492
+ - **Weights source** can be: a `.pt` file, a folder like `weights_latest/` (safetensors or .pt shards), or a **Hub repo id** `org/model`.
493
+ - Update rule per step: **argmax** or **sample** (temperature + option to exclude current token).
494
+ - Tools: Random init, Init from text, Noise brush (select indices, prepend/append noise), Append text, Live denoise.
 
495
  """
496
  )
497
 
498
  # Global settings
499
+ default_source = os.environ.get("WEIGHTS_SOURCE", DEFAULT_WEIGHTS_DIR if os.path.isdir(DEFAULT_WEIGHTS_DIR) else DEFAULT_CKPT)
500
  with gr.Row():
501
+ src = gr.Textbox(value=default_source, label="Weights (file / folder / HF repo id)")
502
  seqlen = gr.Slider(10, 512, value=100, step=1, label="Sequence length (S)")
503
  seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
504
 
 
511
  status = gr.Markdown("Ready.")
512
 
513
  gr.Markdown("## Mode 1 · Random → Denoise Live")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
  with gr.Row():
515
  btn_random = gr.Button("Initialize Random")
516
  steps = gr.Slider(1, 2000, value=200, step=1, label="Denoise steps (N)")
517
  snap_every = gr.Slider(1, 100, value=5, step=1, label="Update every K steps")
518
+ with gr.Row():
519
+ update_mode = gr.Radio(choices=["argmax", "sample"], value="argmax", label="Update rule")
520
+ temperature = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.05, label="Temperature (sampling)")
521
+ exclude_current = gr.Checkbox(value=True, label="Exclude current token when sampling")
522
  with gr.Row():
523
  btn_step_once = gr.Button("Step Once")
524
  btn_live = gr.Button("Denoise Live (streaming)")
 
544
  btn_append = gr.Button("Append to Current Sequence")
545
 
546
  # --- Wiring ---
547
+ btn_random.click(init_random, [src, seqlen, seed], [ids_state, current_text, status])
 
 
 
 
 
548
 
549
+ btn_init_text.click(init_from_text, [src, seqlen, init_text, seed, pad_mode], [ids_state, current_text, status])
 
 
 
 
 
550
 
 
551
  btn_apply_noise.click(
552
+ apply_noise, [src, ids_state, seqlen, indices_csv, add_left, add_right, seed],
 
553
  [ids_state, current_text, status]
554
  )
555
 
556
+ btn_append.click(append_text, [src, ids_state, seqlen, append_box, seed], [ids_state, current_text, status])
 
 
 
 
 
557
 
 
558
  btn_step_once.click(
559
  step_once,
560
+ [src, ids_state, update_mode, temperature, exclude_current],
561
  [ids_state, current_text, status]
562
  )
563
 
 
564
  btn_live.click(
565
  live_denoise,
566
+ [src, ids_state, steps, snap_every, seed, update_mode, temperature, exclude_current],
567
  [ids_state, current_text, status],
568
  show_progress=True
569
  )
570
 
571
  demo.queue().launch()
 
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  torch --extra-index-url https://download.pytorch.org/whl/cpu
2
  transformers>=4.41.0
3
  gradio>=4.31.0
 
 
 
1
  torch --extra-index-url https://download.pytorch.org/whl/cpu
2
  transformers>=4.41.0
3
  gradio>=4.31.0
4
+ safetensors>=0.4.2
5
+ huggingface_hub>=0.23.0