Spaces:
Sleeping
Sleeping
the-puzzler
commited on
Commit
·
515a8b4
1
Parent(s):
fbcf0db
added differnt argmax or sampling lgoits
Browse files
app.py
CHANGED
|
@@ -121,6 +121,32 @@ class CNA(nn.Module):
|
|
| 121 |
# -----------------------------
|
| 122 |
# Helpers
|
| 123 |
# -----------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
def infer_expansion_factor_from_state(state, embed_dim):
|
| 125 |
for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
|
| 126 |
if key in state:
|
|
@@ -271,14 +297,34 @@ def ensure_model(ckpt_path):
|
|
| 271 |
# Strategy 1 core step
|
| 272 |
# -----------------------------
|
| 273 |
@torch.no_grad()
|
| 274 |
-
def step_strategy1(model, x
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
S = x.shape[1]
|
| 277 |
pos = int(torch.randint(0, S, (1,)).item())
|
| 278 |
logits_pos = model_logits(model, x)[0, pos] # [V]
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
return x
|
| 281 |
|
|
|
|
| 282 |
# -----------------------------
|
| 283 |
# Gradio logic
|
| 284 |
# -----------------------------
|
|
@@ -332,17 +378,23 @@ def apply_noise(ckpt_path, state_ids, seqlen, indices_csv, add_left, add_right,
|
|
| 332 |
txt = decode(x[0], tok)
|
| 333 |
return x.tolist(), txt, "Applied noise brush / prepend / append"
|
| 334 |
|
| 335 |
-
def step_once(ckpt_path, state_ids):
|
| 336 |
ensure_model(ckpt_path or DEFAULT_CKPT)
|
| 337 |
tok = model_cache["tokenizer"]
|
| 338 |
if state_ids is None or len(state_ids) == 0:
|
| 339 |
return None, "", "No sequence to step — initialize first."
|
| 340 |
x = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
|
| 341 |
-
x = step_strategy1(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
txt = decode(x[0], tok)
|
| 343 |
-
return x.tolist(), txt, "Stepped 1 iteration"
|
| 344 |
|
| 345 |
-
def live_denoise(ckpt_path, state_ids, steps, snap_every, seed
|
|
|
|
| 346 |
"""
|
| 347 |
Generator for live updates. Yields (ids, text, status) every snap_every steps and on completion.
|
| 348 |
"""
|
|
@@ -355,11 +407,16 @@ def live_denoise(ckpt_path, state_ids, steps, snap_every, seed):
|
|
| 355 |
total = int(steps)
|
| 356 |
snap = max(1, int(snap_every))
|
| 357 |
for t in range(1, total + 1):
|
| 358 |
-
x = step_strategy1(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
if (t % snap == 0) or (t == total):
|
| 360 |
txt = decode(x[0], tok)
|
| 361 |
-
yield x.tolist(), txt, f"Live denoise… step {t}/{total}"
|
| 362 |
-
|
| 363 |
|
| 364 |
# -----------------------------
|
| 365 |
# UI
|
|
@@ -390,6 +447,20 @@ with gr.Blocks(title="CNA — Interactive Denoising (Strategy 1)") as demo:
|
|
| 390 |
status = gr.Markdown("Ready.")
|
| 391 |
|
| 392 |
gr.Markdown("## Mode 1 · Random → Denoise Live")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
with gr.Row():
|
| 394 |
btn_random = gr.Button("Initialize Random")
|
| 395 |
steps = gr.Slider(1, 2000, value=200, step=1, label="Denoise steps (N)")
|
|
@@ -450,14 +521,14 @@ with gr.Blocks(title="CNA — Interactive Denoising (Strategy 1)") as demo:
|
|
| 450 |
# Single step
|
| 451 |
btn_step_once.click(
|
| 452 |
step_once,
|
| 453 |
-
[ckpt, ids_state],
|
| 454 |
[ids_state, current_text, status]
|
| 455 |
)
|
| 456 |
|
| 457 |
# Live denoise (streaming)
|
| 458 |
btn_live.click(
|
| 459 |
live_denoise,
|
| 460 |
-
[ckpt, ids_state, steps, snap_every, seed],
|
| 461 |
[ids_state, current_text, status],
|
| 462 |
show_progress=True
|
| 463 |
)
|
|
|
|
| 121 |
# -----------------------------
|
| 122 |
# Helpers
|
| 123 |
# -----------------------------
|
| 124 |
+
@torch.no_grad()
|
| 125 |
+
def sample_from_logits(logits_row: torch.Tensor, temperature: float = 1.0,
|
| 126 |
+
current_token: int | None = None, exclude_current: bool = True) -> int:
|
| 127 |
+
"""
|
| 128 |
+
Sample a token from logits_row using softmax with temperature.
|
| 129 |
+
If exclude_current=True and current_token is provided, set its prob to 0 (then renormalize).
|
| 130 |
+
"""
|
| 131 |
+
if temperature <= 0:
|
| 132 |
+
# safety: treat as argmax
|
| 133 |
+
return int(torch.argmax(logits_row).item())
|
| 134 |
+
|
| 135 |
+
scaled = logits_row / float(temperature)
|
| 136 |
+
probs = torch.softmax(scaled, dim=-1)
|
| 137 |
+
|
| 138 |
+
if exclude_current and current_token is not None:
|
| 139 |
+
probs = probs.clone()
|
| 140 |
+
probs[current_token] = 0.0
|
| 141 |
+
s = probs.sum()
|
| 142 |
+
if s.item() <= 0:
|
| 143 |
+
# fallback to argmax if everything got zeroed
|
| 144 |
+
return int(torch.argmax(logits_row).item())
|
| 145 |
+
probs = probs / s
|
| 146 |
+
|
| 147 |
+
return int(torch.multinomial(probs, num_samples=1).item())
|
| 148 |
+
|
| 149 |
+
|
| 150 |
def infer_expansion_factor_from_state(state, embed_dim):
|
| 151 |
for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
|
| 152 |
if key in state:
|
|
|
|
| 297 |
# Strategy 1 core step
|
| 298 |
# -----------------------------
|
| 299 |
@torch.no_grad()
|
| 300 |
+
def step_strategy1(model, x, mode: str = "argmax",
|
| 301 |
+
temperature: float = 1.0,
|
| 302 |
+
exclude_current: bool = True):
|
| 303 |
+
"""
|
| 304 |
+
One iteration: choose random position, then update via:
|
| 305 |
+
- mode="argmax": set token to argmax(logits)
|
| 306 |
+
- mode="sample": sample from softmax(logits / temperature)
|
| 307 |
+
(optionally excluding current token)
|
| 308 |
+
"""
|
| 309 |
S = x.shape[1]
|
| 310 |
pos = int(torch.randint(0, S, (1,)).item())
|
| 311 |
logits_pos = model_logits(model, x)[0, pos] # [V]
|
| 312 |
+
|
| 313 |
+
if mode == "sample":
|
| 314 |
+
cur_tok = int(x[0, pos].item())
|
| 315 |
+
new_tok = sample_from_logits(
|
| 316 |
+
logits_pos,
|
| 317 |
+
temperature=float(temperature),
|
| 318 |
+
current_token=cur_tok,
|
| 319 |
+
exclude_current=bool(exclude_current)
|
| 320 |
+
)
|
| 321 |
+
x[0, pos] = new_tok
|
| 322 |
+
else:
|
| 323 |
+
# default / fallback: argmax
|
| 324 |
+
x[0, pos] = int(torch.argmax(logits_pos).item())
|
| 325 |
return x
|
| 326 |
|
| 327 |
+
|
| 328 |
# -----------------------------
|
| 329 |
# Gradio logic
|
| 330 |
# -----------------------------
|
|
|
|
| 378 |
txt = decode(x[0], tok)
|
| 379 |
return x.tolist(), txt, "Applied noise brush / prepend / append"
|
| 380 |
|
| 381 |
+
def step_once(ckpt_path, state_ids, mode, temperature, exclude_current):
|
| 382 |
ensure_model(ckpt_path or DEFAULT_CKPT)
|
| 383 |
tok = model_cache["tokenizer"]
|
| 384 |
if state_ids is None or len(state_ids) == 0:
|
| 385 |
return None, "", "No sequence to step — initialize first."
|
| 386 |
x = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
|
| 387 |
+
x = step_strategy1(
|
| 388 |
+
model_cache["model"], x,
|
| 389 |
+
mode=mode,
|
| 390 |
+
temperature=temperature,
|
| 391 |
+
exclude_current=exclude_current
|
| 392 |
+
)
|
| 393 |
txt = decode(x[0], tok)
|
| 394 |
+
return x.tolist(), txt, f"Stepped 1 iteration ({mode})"
|
| 395 |
|
| 396 |
+
def live_denoise(ckpt_path, state_ids, steps, snap_every, seed,
|
| 397 |
+
mode, temperature, exclude_current):
|
| 398 |
"""
|
| 399 |
Generator for live updates. Yields (ids, text, status) every snap_every steps and on completion.
|
| 400 |
"""
|
|
|
|
| 407 |
total = int(steps)
|
| 408 |
snap = max(1, int(snap_every))
|
| 409 |
for t in range(1, total + 1):
|
| 410 |
+
x = step_strategy1(
|
| 411 |
+
model_cache["model"], x,
|
| 412 |
+
mode=mode,
|
| 413 |
+
temperature=temperature,
|
| 414 |
+
exclude_current=exclude_current
|
| 415 |
+
)
|
| 416 |
if (t % snap == 0) or (t == total):
|
| 417 |
txt = decode(x[0], tok)
|
| 418 |
+
yield x.tolist(), txt, f"Live denoise… step {t}/{total} ({mode})"
|
| 419 |
+
|
| 420 |
|
| 421 |
# -----------------------------
|
| 422 |
# UI
|
|
|
|
| 447 |
status = gr.Markdown("Ready.")
|
| 448 |
|
| 449 |
gr.Markdown("## Mode 1 · Random → Denoise Live")
|
| 450 |
+
with gr.Row():
|
| 451 |
+
update_mode = gr.Radio(
|
| 452 |
+
choices=["argmax", "sample"],
|
| 453 |
+
value="argmax",
|
| 454 |
+
label="Update rule"
|
| 455 |
+
)
|
| 456 |
+
temperature = gr.Slider(
|
| 457 |
+
minimum=0.0, maximum=5.0, value=1.0, step=0.05,
|
| 458 |
+
label="Temperature (sampling)"
|
| 459 |
+
)
|
| 460 |
+
exclude_current = gr.Checkbox(
|
| 461 |
+
value=True,
|
| 462 |
+
label="Exclude current token when sampling"
|
| 463 |
+
)
|
| 464 |
with gr.Row():
|
| 465 |
btn_random = gr.Button("Initialize Random")
|
| 466 |
steps = gr.Slider(1, 2000, value=200, step=1, label="Denoise steps (N)")
|
|
|
|
| 521 |
# Single step
|
| 522 |
btn_step_once.click(
|
| 523 |
step_once,
|
| 524 |
+
[ckpt, ids_state, update_mode, temperature, exclude_current],
|
| 525 |
[ids_state, current_text, status]
|
| 526 |
)
|
| 527 |
|
| 528 |
# Live denoise (streaming)
|
| 529 |
btn_live.click(
|
| 530 |
live_denoise,
|
| 531 |
+
[ckpt, ids_state, steps, snap_every, seed, update_mode, temperature, exclude_current],
|
| 532 |
[ids_state, current_text, status],
|
| 533 |
show_progress=True
|
| 534 |
)
|