Streaming progress bar, rename to Top N, fix scoring

- Run model.generate() in background thread with BaseStreamer token counter
- Yield progress bar updates during generation (█░ style)
- Rename "Samples (best-of-N)" → "Top N"
- Show per-sample scoring status during annotation phase
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
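
Reviewer note: the core of this change is the streamer-as-progress-counter pattern — generate() pushes each decoding step into the streamer's put(), so a counting subclass plus a polling loop gives a live progress readout without true token streaming. Below is a minimal standalone sketch of that pattern; the "gpt2" model, prompt, and token budget are illustrative placeholders, not this Space's configuration.

import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, BaseStreamer


class TokenCounter(BaseStreamer):
    """Minimal streamer: generate() calls put() once per decoding step."""

    def __init__(self):
        self.step = 0
        self.done = False

    def put(self, value):
        # Note: the very first put() delivers the prompt tokens, so `step`
        # runs one call ahead of the number of newly generated tokens.
        self.step += 1

    def end(self):
        self.done = True


tok = AutoTokenizer.from_pretrained("gpt2")            # illustrative model
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("Hello", return_tensors="pt")

counter = TokenCounter()
max_new = 40
worker = threading.Thread(
    target=lambda: model.generate(
        **inputs, max_new_tokens=max_new, do_sample=True, streamer=counter
    )
)
worker.start()
while worker.is_alive():
    print(f"\r{counter.step}/{max_new} steps", end="", flush=True)
    worker.join(timeout=0.25)  # doubles as the polling sleep
print(f"\nfinished: done={counter.done}")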
app.py CHANGED
@@ -52,7 +52,7 @@ from bokeh.resources import CDN as BOKEH_CDN
 from plannotate.annotate import annotate as _plannotate_annotate
 from plannotate.bokeh_plot import get_bokeh as _plannotate_bokeh
 from plannotate import resources as _plannotate_rsc
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BaseStreamer
 
 # ---------------------------------------------------------------------------
 # Configuration
@@ -500,6 +500,27 @@ def _score_annotation(
     return composite
 
 
+class _TokenCounter(BaseStreamer):
+    """Counts generation steps so the UI can show progress."""
+
+    def __init__(self):
+        self.step = 0
+        self.done = False
+
+    def put(self, value):
+        self.step += 1
+
+    def end(self):
+        self.done = True
+
+
+def _progress_bar(step: int, total: int, width: int = 20) -> str:
+    frac = min(step / max(total, 1), 1.0)
+    filled = int(width * frac)
+    bar = "█" * filled + "░" * (width - filled)
+    return f"[{bar}] {step}/{total} tokens ({frac:.0%})"
+
+
 def generate_and_select(
     prompt_text: str,
     temperature: float,
@@ -513,6 +534,7 @@ def generate_and_select(
 
     prompt_text = _ensure_prompt_format(prompt_text)
     num_samples = max(1, int(num_samples))
+    max_tokens = int(max_tokens)
     print(f"[generate] prompt: {prompt_text!r}, n={num_samples}, "
           f"temp={temperature}, max_tokens={max_tokens}")
 
@@ -523,25 +545,46 @@ def generate_and_select(
         padding=True,
     ).to(DEVICE)
 
-    # Run generation
+    # Run generation in background thread with token counter
+    counter = _TokenCounter()
+    result_holder: list = [None, None]  # [outputs, error]
+
+    def _run_generate():
+        try:
+            with torch.no_grad():
+                result_holder[0] = model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    temperature=float(temperature),
+                    do_sample=True,
+                    top_k=50,
+                    use_cache=True,
+                    streamer=counter,
+                )
+        except Exception as exc:
+            result_holder[1] = exc
 
     t0 = time.time()
-    try:
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=int(max_tokens),
-                temperature=float(temperature),
-                do_sample=True,
-                top_k=50,
-                use_cache=True,
-            )
-    except Exception as exc:
-        print(f"[generate] ERROR: {exc}")
-        yield "", f"Generation failed: {exc}", None, None, ""
-        return
+    gen_thread = threading.Thread(target=_run_generate)
+    gen_thread.start()
+
+    # Poll progress and yield status updates
+    n_label = f"{num_samples} sample(s)" if num_samples > 1 else "1 sample"
+    while gen_thread.is_alive():
+        elapsed = time.time() - t0
+        bar = _progress_bar(counter.step, max_tokens)
+        yield "", f"Generating {n_label}… {bar} ({elapsed:.1f}s)", None, None, ""
+        gen_thread.join(timeout=0.4)
+
     gen_time = time.time() - t0
 
+    if result_holder[1] is not None:
+        print(f"[generate] ERROR: {result_holder[1]}")
+        yield "", f"Generation failed: {result_holder[1]}", None, None, ""
+        return
+
+    outputs = result_holder[0]
+
     # Decode all samples
     samples = []
     for i in range(outputs.shape[0]):
@@ -553,7 +596,7 @@ def generate_and_select(
         print(f"[generate] sample {i}: {len(dna)} bp, "
               f"{'complete' if has_eos else 'truncated'}")
 
-    yield "", (f"Generated {num_samples} sample(s) in {gen_time:.1f}s. "
+    yield "", (f"Generated {n_label} ({counter.step} tokens) in {gen_time:.1f}s. "
               "Annotating…"), None, None, ""
 
     # If only 1 sample, skip scoring
@@ -571,6 +614,7 @@ def generate_and_select(
     for i, (dna, _) in enumerate(samples):
        if len(dna) < 100:
             continue
+        yield "", (f"Scoring sample {i+1}/{num_samples}…"), None, None, ""
        try:
             hits = _plannotate_annotate(dna, is_detailed=True, linear=False)
             score = _score_annotation(hits, prompt=prompt_text)
@@ -585,6 +629,7 @@ def generate_and_select(
     tag = "complete" if has_eos else "max-length"
     html_map, table, ann_status = _annotate(dna)
     status = (f"Best of {num_samples}: sample {best_idx+1}, "
+              f"score={best_score:.2f}, "
              f"{len(dna)} bp ({tag}, {gen_time:.1f}s). {ann_status}")
     yield dna, status, html_map, table, ""
 
@@ -712,7 +757,7 @@ with gr.Blocks(title="PlasmidSpace") as demo:
             0.1, 1.0, value=0.3, step=0.05, label="Temperature",
         )
         num_samples = gr.Slider(
-            1, 8, value=1, step=1, label="Samples (best-of-N)",
+            1, 8, value=1, step=1, label="Top N",
         )
 
         max_tokens = gr.Slider(
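
A note on the helper as committed: because generate() also calls put() once with the prompt itself, counter.step runs slightly ahead of the number of new tokens and can exceed max_tokens; the min(..., 1.0) clamp keeps the bar from overflowing, and max(total, 1) guards against a zero token budget. A quick standalone check — the function body is copied verbatim from the diff above:

def _progress_bar(step: int, total: int, width: int = 20) -> str:
    frac = min(step / max(total, 1), 1.0)  # clamp: step may overshoot total
    filled = int(width * frac)
    bar = "█" * filled + "░" * (width - filled)
    return f"[{bar}] {step}/{total} tokens ({frac:.0%})"

print(_progress_bar(0, 1000))    # [░░░░░░░░░░░░░░░░░░░░] 0/1000 tokens (0%)
print(_progress_bar(500, 1000))  # [██████████░░░░░░░░░░] 500/1000 tokens (50%)
print(_progress_bar(1200, 1000)) # bar caps at 100%; raw count still shows 1200/1000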