Spaces:
Sleeping
Sleeping
the-puzzler
commited on
Commit
·
22b6693
1
Parent(s):
6fb6b5d
minor
Browse files
app.py
CHANGED
|
@@ -119,7 +119,7 @@ class CNA(nn.Module):
|
|
| 119 |
return self.proj(h)
|
| 120 |
|
| 121 |
# -----------------------------
|
| 122 |
-
# Helpers
|
| 123 |
# -----------------------------
|
| 124 |
def infer_expansion_factor_from_state(state, embed_dim):
|
| 125 |
for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
|
|
@@ -132,18 +132,84 @@ def infer_expansion_factor_from_state(state, embed_dim):
|
|
| 132 |
return 4
|
| 133 |
|
| 134 |
@torch.no_grad()
|
| 135 |
-
def decode(ids, tokenizer, max_chars=
|
| 136 |
s = tokenizer.decode(ids.tolist(), skip_special_tokens=True)
|
| 137 |
s = s.replace("\n", " ")
|
| 138 |
return s[:max_chars] + ("…" if len(s) > max_chars else "")
|
| 139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
@torch.no_grad()
|
| 141 |
def model_logits(model, x):
|
| 142 |
return model(x)
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
# -----------------------------
|
| 145 |
# Load checkpoint & build model
|
| 146 |
# -----------------------------
|
|
|
|
|
|
|
|
|
|
| 147 |
def load_model(ckpt_path: str):
|
| 148 |
if not os.path.exists(ckpt_path):
|
| 149 |
raise FileNotFoundError(
|
|
@@ -196,68 +262,204 @@ def load_model(ckpt_path: str):
|
|
| 196 |
model.eval()
|
| 197 |
return model, tokenizer, int(radius)
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
# -----------------------------
|
| 200 |
-
#
|
| 201 |
# -----------------------------
|
| 202 |
@torch.no_grad()
|
| 203 |
-
def
|
| 204 |
-
random
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
logits_pos = model_logits(model, x)[0, pos] # [V]
|
| 211 |
-
x[0, pos] = int(torch.argmax(logits_pos).item())
|
| 212 |
-
if (t % snap_every == 0) or (t == steps):
|
| 213 |
-
snaps.append((t, decode(x[0].cpu(), tokenizer, max_chars)))
|
| 214 |
-
return snaps
|
| 215 |
|
| 216 |
# -----------------------------
|
| 217 |
-
# Gradio
|
| 218 |
# -----------------------------
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
|
|
|
| 226 |
|
| 227 |
-
def
|
| 228 |
ensure_model(ckpt_path or DEFAULT_CKPT)
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
| 240 |
gr.Markdown(
|
| 241 |
"""
|
| 242 |
-
# CNA —
|
| 243 |
-
|
| 244 |
-
-
|
|
|
|
|
|
|
| 245 |
"""
|
| 246 |
)
|
|
|
|
|
|
|
| 247 |
with gr.Row():
|
| 248 |
-
ckpt = gr.Textbox(value=DEFAULT_CKPT, label="Checkpoint path"
|
| 249 |
-
with gr.Row():
|
| 250 |
seqlen = gr.Slider(10, 512, value=100, step=1, label="Sequence length (S)")
|
| 251 |
-
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
with gr.Row():
|
| 254 |
-
|
| 255 |
-
max_chars = gr.Slider(32, 1000, value=220, step=1, label="Max chars per snapshot")
|
| 256 |
-
run_btn = gr.Button("Run")
|
| 257 |
with gr.Row():
|
| 258 |
-
|
| 259 |
-
|
| 260 |
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
demo.queue(concurrency_count=1).launch()
|
|
|
|
| 119 |
return self.proj(h)
|
| 120 |
|
| 121 |
# -----------------------------
|
| 122 |
+
# Helpers
|
| 123 |
# -----------------------------
|
| 124 |
def infer_expansion_factor_from_state(state, embed_dim):
|
| 125 |
for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
|
|
|
|
| 132 |
return 4
|
| 133 |
|
| 134 |
@torch.no_grad()
|
| 135 |
+
def decode(ids, tokenizer, max_chars=1000):
|
| 136 |
s = tokenizer.decode(ids.tolist(), skip_special_tokens=True)
|
| 137 |
s = s.replace("\n", " ")
|
| 138 |
return s[:max_chars] + ("…" if len(s) > max_chars else "")
|
| 139 |
|
| 140 |
+
def to_fixed_len_ids(text, tokenizer, seqlen, pad_mode="random", rnd=None):
|
| 141 |
+
"""Encode text and force to length seqlen."""
|
| 142 |
+
if rnd is None:
|
| 143 |
+
rnd = random.Random()
|
| 144 |
+
ids = tokenizer.encode(text, add_special_tokens=False)
|
| 145 |
+
V = tokenizer.vocab_size
|
| 146 |
+
if len(ids) >= seqlen:
|
| 147 |
+
ids = ids[:seqlen]
|
| 148 |
+
else:
|
| 149 |
+
need = seqlen - len(ids)
|
| 150 |
+
if pad_mode == "eos" and tokenizer.eos_token_id is not None:
|
| 151 |
+
ids = ids + [tokenizer.eos_token_id] * need
|
| 152 |
+
else:
|
| 153 |
+
ids = ids + [rnd.randrange(V) for _ in range(need)]
|
| 154 |
+
return torch.tensor(ids, dtype=torch.long).unsqueeze(0)
|
| 155 |
+
|
| 156 |
@torch.no_grad()
|
| 157 |
def model_logits(model, x):
|
| 158 |
return model(x)
|
| 159 |
|
| 160 |
+
def apply_noise_ops(x, tokenizer, indices_csv, add_noise_left, add_noise_right, seqlen, seed=0):
|
| 161 |
+
"""Noise selected positions and optionally prepend/append random tokens."""
|
| 162 |
+
rnd = random.Random(seed)
|
| 163 |
+
V = tokenizer.vocab_size
|
| 164 |
+
x = x.clone()
|
| 165 |
+
|
| 166 |
+
# noise brush (indices like "0, 5, 6-10")
|
| 167 |
+
idxs = set()
|
| 168 |
+
if indices_csv.strip():
|
| 169 |
+
for part in indices_csv.split(","):
|
| 170 |
+
part = part.strip()
|
| 171 |
+
if not part:
|
| 172 |
+
continue
|
| 173 |
+
if "-" in part:
|
| 174 |
+
a, b = part.split("-", 1)
|
| 175 |
+
try:
|
| 176 |
+
a, b = int(a), int(b)
|
| 177 |
+
for j in range(min(a,b), max(a,b)+1):
|
| 178 |
+
idxs.add(j)
|
| 179 |
+
except:
|
| 180 |
+
continue
|
| 181 |
+
else:
|
| 182 |
+
try:
|
| 183 |
+
idxs.add(int(part))
|
| 184 |
+
except:
|
| 185 |
+
continue
|
| 186 |
+
for j in idxs:
|
| 187 |
+
if 0 <= j < seqlen:
|
| 188 |
+
x[0, j] = rnd.randrange(V)
|
| 189 |
+
|
| 190 |
+
# prepend/append random noise
|
| 191 |
+
if add_noise_left > 0:
|
| 192 |
+
prefix = torch.tensor([rnd.randrange(V) for _ in range(add_noise_left)], dtype=torch.long).unsqueeze(0)
|
| 193 |
+
x = torch.cat([prefix, x], dim=1)
|
| 194 |
+
if add_noise_right > 0:
|
| 195 |
+
suffix = torch.tensor([rnd.randrange(V) for _ in range(add_noise_right)], dtype=torch.long).unsqueeze(0)
|
| 196 |
+
x = torch.cat([x, suffix], dim=1)
|
| 197 |
+
|
| 198 |
+
# force length back to seqlen (trim or pad random)
|
| 199 |
+
if x.shape[1] > seqlen:
|
| 200 |
+
x = x[:, :seqlen]
|
| 201 |
+
elif x.shape[1] < seqlen:
|
| 202 |
+
need = seqlen - x.shape[1]
|
| 203 |
+
pad = torch.tensor([rnd.randrange(V) for _ in range(need)], dtype=torch.long).unsqueeze(0)
|
| 204 |
+
x = torch.cat([x, pad], dim=1)
|
| 205 |
+
return x
|
| 206 |
+
|
| 207 |
# -----------------------------
|
| 208 |
# Load checkpoint & build model
|
| 209 |
# -----------------------------
|
| 210 |
+
DEFAULT_CKPT = os.environ.get("CKPT_PATH", "ckpt_latest.pt")
|
| 211 |
+
model_cache = {"model": None, "tokenizer": None, "radius": None, "ckpt": None}
|
| 212 |
+
|
| 213 |
def load_model(ckpt_path: str):
|
| 214 |
if not os.path.exists(ckpt_path):
|
| 215 |
raise FileNotFoundError(
|
|
|
|
| 262 |
model.eval()
|
| 263 |
return model, tokenizer, int(radius)
|
| 264 |
|
| 265 |
+
def ensure_model(ckpt_path):
|
| 266 |
+
if model_cache["model"] is None or model_cache["ckpt"] != ckpt_path:
|
| 267 |
+
m, tok, rad = load_model(ckpt_path)
|
| 268 |
+
model_cache.update({"model": m, "tokenizer": tok, "radius": rad, "ckpt": ckpt_path})
|
| 269 |
+
|
| 270 |
# -----------------------------
|
| 271 |
+
# Strategy 1 core step
|
| 272 |
# -----------------------------
|
| 273 |
@torch.no_grad()
|
| 274 |
+
def step_strategy1(model, x):
|
| 275 |
+
"""One iteration: choose random position, set to argmax(logits)."""
|
| 276 |
+
S = x.shape[1]
|
| 277 |
+
pos = int(torch.randint(0, S, (1,)).item())
|
| 278 |
+
logits_pos = model_logits(model, x)[0, pos] # [V]
|
| 279 |
+
x[0, pos] = int(torch.argmax(logits_pos).item())
|
| 280 |
+
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
# -----------------------------
|
| 283 |
+
# Gradio logic
|
| 284 |
# -----------------------------
|
| 285 |
+
def init_random(ckpt_path, seqlen, seed):
|
| 286 |
+
ensure_model(ckpt_path or DEFAULT_CKPT)
|
| 287 |
+
random.seed(seed); torch.manual_seed(seed)
|
| 288 |
+
V = model_cache["tokenizer"].vocab_size
|
| 289 |
+
x = torch.randint(0, V, (1, seqlen))
|
| 290 |
+
txt = decode(x[0], model_cache["tokenizer"])
|
| 291 |
+
return x.tolist(), txt, f"Initialized random sequence (len={seqlen})"
|
| 292 |
|
| 293 |
+
def init_from_text(ckpt_path, seqlen, text, seed, pad_mode):
|
| 294 |
+
ensure_model(ckpt_path or DEFAULT_CKPT)
|
| 295 |
+
rnd = random.Random(seed)
|
| 296 |
+
x = to_fixed_len_ids(text or "", model_cache["tokenizer"], seqlen, pad_mode=pad_mode, rnd=rnd)
|
| 297 |
+
txt = decode(x[0], model_cache["tokenizer"])
|
| 298 |
+
return x.tolist(), txt, "Initialized from text"
|
| 299 |
|
| 300 |
+
def append_text(ckpt_path, state_ids, seqlen, text_to_append, seed):
|
| 301 |
ensure_model(ckpt_path or DEFAULT_CKPT)
|
| 302 |
+
tok = model_cache["tokenizer"]
|
| 303 |
+
rnd = random.Random(seed)
|
| 304 |
+
if state_ids is None or len(state_ids) == 0:
|
| 305 |
+
x = to_fixed_len_ids(text_to_append or "", tok, seqlen, pad_mode="random", rnd=rnd)
|
| 306 |
+
else:
|
| 307 |
+
x = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
|
| 308 |
+
# append
|
| 309 |
+
extra = tok.encode(text_to_append or "", add_special_tokens=False)
|
| 310 |
+
x = torch.cat([x, torch.tensor(extra, dtype=torch.long).unsqueeze(0)], dim=1)
|
| 311 |
+
# force length
|
| 312 |
+
if x.shape[1] > seqlen:
|
| 313 |
+
x = x[:, :seqlen]
|
| 314 |
+
elif x.shape[1] < seqlen:
|
| 315 |
+
need = seqlen - x.shape[1]
|
| 316 |
+
V = tok.vocab_size
|
| 317 |
+
pad = torch.tensor([rnd.randrange(V) for _ in range(need)], dtype=torch.long).unsqueeze(0)
|
| 318 |
+
x = torch.cat([x, pad], dim=1)
|
| 319 |
+
txt = decode(x[0], tok)
|
| 320 |
+
return x.tolist(), txt, "Appended text and resized to target length"
|
| 321 |
+
|
| 322 |
+
def apply_noise(ckpt_path, state_ids, seqlen, indices_csv, add_left, add_right, seed):
|
| 323 |
+
ensure_model(ckpt_path or DEFAULT_CKPT)
|
| 324 |
+
tok = model_cache["tokenizer"]
|
| 325 |
+
if state_ids is None or len(state_ids) == 0:
|
| 326 |
+
# create an empty base (random) then apply ops
|
| 327 |
+
V = tok.vocab_size
|
| 328 |
+
base = torch.randint(0, V, (1, seqlen))
|
| 329 |
+
else:
|
| 330 |
+
base = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
|
| 331 |
+
x = apply_noise_ops(base, tok, indices_csv, int(add_left), int(add_right), seqlen, seed=seed)
|
| 332 |
+
txt = decode(x[0], tok)
|
| 333 |
+
return x.tolist(), txt, "Applied noise brush / prepend / append"
|
| 334 |
+
|
| 335 |
+
def step_once(ckpt_path, state_ids):
|
| 336 |
+
ensure_model(ckpt_path or DEFAULT_CKPT)
|
| 337 |
+
tok = model_cache["tokenizer"]
|
| 338 |
+
if state_ids is None or len(state_ids) == 0:
|
| 339 |
+
return None, "", "No sequence to step — initialize first."
|
| 340 |
+
x = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
|
| 341 |
+
x = step_strategy1(model_cache["model"], x)
|
| 342 |
+
txt = decode(x[0], tok)
|
| 343 |
+
return x.tolist(), txt, "Stepped 1 iteration"
|
| 344 |
+
|
| 345 |
+
def live_denoise(ckpt_path, state_ids, steps, snap_every, seed):
|
| 346 |
+
"""
|
| 347 |
+
Generator for live updates. Yields (ids, text, status) every snap_every steps and on completion.
|
| 348 |
+
"""
|
| 349 |
+
ensure_model(ckpt_path or DEFAULT_CKPT)
|
| 350 |
+
tok = model_cache["tokenizer"]
|
| 351 |
+
if state_ids is None or len(state_ids) == 0:
|
| 352 |
+
return
|
| 353 |
+
random.seed(seed); torch.manual_seed(seed)
|
| 354 |
+
x = torch.tensor(state_ids, dtype=torch.long).unsqueeze(0)
|
| 355 |
+
total = int(steps)
|
| 356 |
+
snap = max(1, int(snap_every))
|
| 357 |
+
for t in range(1, total + 1):
|
| 358 |
+
x = step_strategy1(model_cache["model"], x)
|
| 359 |
+
if (t % snap == 0) or (t == total):
|
| 360 |
+
txt = decode(x[0], tok)
|
| 361 |
+
yield x.tolist(), txt, f"Live denoise… step {t}/{total}"
|
| 362 |
+
# final yield already done in loop
|
| 363 |
|
| 364 |
+
# -----------------------------
|
| 365 |
+
# UI
|
| 366 |
+
# -----------------------------
|
| 367 |
+
with gr.Blocks(title="CNA — Interactive Denoising (Strategy 1)") as demo:
|
| 368 |
gr.Markdown(
|
| 369 |
"""
|
| 370 |
+
# CNA — Interactive Denoising (Strategy 1)
|
| 371 |
+
- **Mode 1:** Randomize then watch it **denoise live** (random-position → argmax).
|
| 372 |
+
- **Mode 2:** Initialize from **your text**.
|
| 373 |
+
- **Noise Brush:** Select positions (e.g., `0, 5, 10-20`), and/or add random noise tokens at **start**/**end**.
|
| 374 |
+
- **Append:** Add your text to the current sequence.
|
| 375 |
"""
|
| 376 |
)
|
| 377 |
+
|
| 378 |
+
# Global settings
|
| 379 |
with gr.Row():
|
| 380 |
+
ckpt = gr.Textbox(value=DEFAULT_CKPT, label="Checkpoint path")
|
|
|
|
| 381 |
seqlen = gr.Slider(10, 512, value=100, step=1, label="Sequence length (S)")
|
| 382 |
+
seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
|
| 383 |
+
|
| 384 |
+
# Hidden state (ids list)
|
| 385 |
+
ids_state = gr.State(value=None)
|
| 386 |
+
|
| 387 |
+
# Displays
|
| 388 |
+
with gr.Row():
|
| 389 |
+
current_text = gr.Textbox(lines=8, label="Current text", interactive=False)
|
| 390 |
+
status = gr.Markdown("Ready.")
|
| 391 |
+
|
| 392 |
+
gr.Markdown("## Mode 1 · Random → Denoise Live")
|
| 393 |
+
with gr.Row():
|
| 394 |
+
btn_random = gr.Button("Initialize Random")
|
| 395 |
+
steps = gr.Slider(1, 2000, value=200, step=1, label="Denoise steps (N)")
|
| 396 |
+
snap_every = gr.Slider(1, 100, value=5, step=1, label="Update every K steps")
|
| 397 |
+
with gr.Row():
|
| 398 |
+
btn_step_once = gr.Button("Step Once")
|
| 399 |
+
btn_live = gr.Button("Denoise Live (streaming)")
|
| 400 |
+
|
| 401 |
+
gr.Markdown("## Mode 2 · Initialize From Your Text")
|
| 402 |
with gr.Row():
|
| 403 |
+
init_text = gr.Textbox(lines=4, label="Initial text")
|
|
|
|
|
|
|
| 404 |
with gr.Row():
|
| 405 |
+
pad_mode = gr.Radio(choices=["random", "eos"], value="random", label="Pad mode (if text shorter than S)")
|
| 406 |
+
btn_init_text = gr.Button("Initialize From Text")
|
| 407 |
|
| 408 |
+
gr.Markdown("## Noise Brush · Select Positions + Prepend/Append Noise")
|
| 409 |
+
with gr.Row():
|
| 410 |
+
indices_csv = gr.Textbox(label="Positions to noise (e.g., 0, 5, 10-20)", placeholder="Leave empty to skip")
|
| 411 |
+
with gr.Row():
|
| 412 |
+
add_left = gr.Number(value=0, precision=0, label="Noise tokens to add at START")
|
| 413 |
+
add_right = gr.Number(value=0, precision=0, label="Noise tokens to add at END")
|
| 414 |
+
btn_apply_noise = gr.Button("Apply Noise Brush / Prepend / Append")
|
| 415 |
+
|
| 416 |
+
gr.Markdown("## Append Text")
|
| 417 |
+
with gr.Row():
|
| 418 |
+
append_box = gr.Textbox(lines=3, label="Text to append")
|
| 419 |
+
btn_append = gr.Button("Append to Current Sequence")
|
| 420 |
+
|
| 421 |
+
# --- Wiring ---
|
| 422 |
+
# Random init
|
| 423 |
+
out = btn_random.click(
|
| 424 |
+
init_random,
|
| 425 |
+
[ckpt, seqlen, seed],
|
| 426 |
+
[ids_state, current_text, status]
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
# Init from text
|
| 430 |
+
btn_init_text.click(
|
| 431 |
+
init_from_text,
|
| 432 |
+
[ckpt, seqlen, init_text, seed, pad_mode],
|
| 433 |
+
[ids_state, current_text, status]
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
# Apply noise
|
| 437 |
+
btn_apply_noise.click(
|
| 438 |
+
apply_noise,
|
| 439 |
+
[ckpt, ids_state, seqlen, indices_csv, add_left, add_right, seed],
|
| 440 |
+
[ids_state, current_text, status]
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
# Append text
|
| 444 |
+
btn_append.click(
|
| 445 |
+
append_text,
|
| 446 |
+
[ckpt, ids_state, seqlen, append_box, seed],
|
| 447 |
+
[ids_state, current_text, status]
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
# Single step
|
| 451 |
+
btn_step_once.click(
|
| 452 |
+
step_once,
|
| 453 |
+
[ckpt, ids_state],
|
| 454 |
+
[ids_state, current_text, status]
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
# Live denoise (streaming)
|
| 458 |
+
btn_live.click(
|
| 459 |
+
live_denoise,
|
| 460 |
+
[ckpt, ids_state, steps, snap_every, seed],
|
| 461 |
+
[ids_state, current_text, status],
|
| 462 |
+
show_progress=True
|
| 463 |
+
)
|
| 464 |
|
| 465 |
demo.queue(concurrency_count=1).launch()
|