McClain Claude Opus 4.6 committed on
Commit
74ad07d
·
1 Parent(s): 5d0ce94

Streaming progress bar, rename to Top N, fix scoring

Browse files

- Run model.generate() in background thread with BaseStreamer token counter
- Yield progress bar updates during generation (█░ style)
- Rename "Samples (best-of-N)" → "Top N"
- Show per-sample scoring status during annotation phase

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +63 -18
app.py CHANGED
@@ -52,7 +52,7 @@ from bokeh.resources import CDN as BOKEH_CDN
52
  from plannotate.annotate import annotate as _plannotate_annotate
53
  from plannotate.bokeh_plot import get_bokeh as _plannotate_bokeh
54
  from plannotate import resources as _plannotate_rsc
55
- from transformers import AutoModelForCausalLM, AutoTokenizer
56
 
57
  # ---------------------------------------------------------------------------
58
  # Configuration
@@ -500,6 +500,27 @@ def _score_annotation(
500
  return composite
501
 
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  def generate_and_select(
504
  prompt_text: str,
505
  temperature: float,
@@ -513,6 +534,7 @@ def generate_and_select(
513
 
514
  prompt_text = _ensure_prompt_format(prompt_text)
515
  num_samples = max(1, int(num_samples))
 
516
  print(f"[generate] prompt: {prompt_text!r}, n={num_samples}, "
517
  f"temp={temperature}, max_tokens={max_tokens}")
518
 
@@ -523,25 +545,46 @@ def generate_and_select(
523
  padding=True,
524
  ).to(DEVICE)
525
 
526
- yield "", f"Generating {num_samples} sample(s)…", None, None, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
 
528
  t0 = time.time()
529
- try:
530
- with torch.no_grad():
531
- outputs = model.generate(
532
- **inputs,
533
- max_new_tokens=int(max_tokens),
534
- temperature=float(temperature),
535
- do_sample=True,
536
- top_k=50,
537
- use_cache=True,
538
- )
539
- except Exception as exc:
540
- print(f"[generate] ERROR: {exc}")
541
- yield "", f"Generation failed: {exc}", None, None, ""
542
- return
543
  gen_time = time.time() - t0
544
 
 
 
 
 
 
 
 
545
  # Decode all samples
546
  samples = []
547
  for i in range(outputs.shape[0]):
@@ -553,7 +596,7 @@ def generate_and_select(
553
  print(f"[generate] sample {i}: {len(dna)} bp, "
554
  f"{'complete' if has_eos else 'truncated'}")
555
 
556
- yield "", (f"Generated {num_samples} sample(s) in {gen_time:.1f}s. "
557
  "Annotating…"), None, None, ""
558
 
559
  # If only 1 sample, skip scoring
@@ -571,6 +614,7 @@ def generate_and_select(
571
  for i, (dna, _) in enumerate(samples):
572
  if len(dna) < 100:
573
  continue
 
574
  try:
575
  hits = _plannotate_annotate(dna, is_detailed=True, linear=False)
576
  score = _score_annotation(hits, prompt=prompt_text)
@@ -585,6 +629,7 @@ def generate_and_select(
585
  tag = "complete" if has_eos else "max-length"
586
  html_map, table, ann_status = _annotate(dna)
587
  status = (f"Best of {num_samples}: sample {best_idx+1}, "
 
588
  f"{len(dna)} bp ({tag}, {gen_time:.1f}s). {ann_status}")
589
  yield dna, status, html_map, table, ""
590
 
@@ -712,7 +757,7 @@ with gr.Blocks(title="PlasmidSpace") as demo:
712
  0.1, 1.0, value=0.3, step=0.05, label="Temperature",
713
  )
714
  num_samples = gr.Slider(
715
- 1, 8, value=1, step=1, label="Samples (best-of-N)",
716
  )
717
 
718
  max_tokens = gr.Slider(
 
52
  from plannotate.annotate import annotate as _plannotate_annotate
53
  from plannotate.bokeh_plot import get_bokeh as _plannotate_bokeh
54
  from plannotate import resources as _plannotate_rsc
55
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BaseStreamer
56
 
57
  # ---------------------------------------------------------------------------
58
  # Configuration
 
500
  return composite
501
 
502
 
503
class _TokenCounter(BaseStreamer):
    """Counts generation steps so the UI can show progress.

    ``model.generate()`` calls ``put()`` first with the prompt ids and then
    once per newly generated token (this is why ``TextStreamer`` exposes a
    ``skip_prompt`` option).  The first call is therefore skipped so that
    ``step`` counts only generated tokens; otherwise the progress bar would
    start at 1 before any token exists and over-report the final count.
    """

    def __init__(self):
        self.step = 0            # number of generated tokens observed so far
        self.done = False        # set True once generate() signals completion
        self._seen_prompt = False  # first put() carries the prompt, not a token

    def put(self, value):
        """Called by generate() with the prompt, then once per new token."""
        if not self._seen_prompt:
            self._seen_prompt = True
            return
        self.step += 1

    def end(self):
        """Called by generate() when generation has finished."""
        self.done = True
515
+
516
+
517
+ def _progress_bar(step: int, total: int, width: int = 20) -> str:
518
+ frac = min(step / max(total, 1), 1.0)
519
+ filled = int(width * frac)
520
+ bar = "█" * filled + "░" * (width - filled)
521
+ return f"[{bar}] {step}/{total} tokens ({frac:.0%})"
522
+
523
+
524
  def generate_and_select(
525
  prompt_text: str,
526
  temperature: float,
 
534
 
535
  prompt_text = _ensure_prompt_format(prompt_text)
536
  num_samples = max(1, int(num_samples))
537
+ max_tokens = int(max_tokens)
538
  print(f"[generate] prompt: {prompt_text!r}, n={num_samples}, "
539
  f"temp={temperature}, max_tokens={max_tokens}")
540
 
 
545
  padding=True,
546
  ).to(DEVICE)
547
 
548
+ # Run generation in background thread with token counter
549
+ counter = _TokenCounter()
550
+ result_holder: list = [None, None] # [outputs, error]
551
+
552
+ def _run_generate():
553
+ try:
554
+ with torch.no_grad():
555
+ result_holder[0] = model.generate(
556
+ **inputs,
557
+ max_new_tokens=max_tokens,
558
+ temperature=float(temperature),
559
+ do_sample=True,
560
+ top_k=50,
561
+ use_cache=True,
562
+ streamer=counter,
563
+ )
564
+ except Exception as exc:
565
+ result_holder[1] = exc
566
 
567
  t0 = time.time()
568
+ gen_thread = threading.Thread(target=_run_generate)
569
+ gen_thread.start()
570
+
571
+ # Poll progress and yield status updates
572
+ n_label = f"{num_samples} sample(s)" if num_samples > 1 else "1 sample"
573
+ while gen_thread.is_alive():
574
+ elapsed = time.time() - t0
575
+ bar = _progress_bar(counter.step, max_tokens)
576
+ yield "", f"Generating {n_label}… {bar} ({elapsed:.1f}s)", None, None, ""
577
+ gen_thread.join(timeout=0.4)
578
+
 
 
 
579
  gen_time = time.time() - t0
580
 
581
+ if result_holder[1] is not None:
582
+ print(f"[generate] ERROR: {result_holder[1]}")
583
+ yield "", f"Generation failed: {result_holder[1]}", None, None, ""
584
+ return
585
+
586
+ outputs = result_holder[0]
587
+
588
  # Decode all samples
589
  samples = []
590
  for i in range(outputs.shape[0]):
 
596
  print(f"[generate] sample {i}: {len(dna)} bp, "
597
  f"{'complete' if has_eos else 'truncated'}")
598
 
599
+ yield "", (f"Generated {n_label} ({counter.step} tokens) in {gen_time:.1f}s. "
600
  "Annotating…"), None, None, ""
601
 
602
  # If only 1 sample, skip scoring
 
614
  for i, (dna, _) in enumerate(samples):
615
  if len(dna) < 100:
616
  continue
617
+ yield "", (f"Scoring sample {i+1}/{num_samples}…"), None, None, ""
618
  try:
619
  hits = _plannotate_annotate(dna, is_detailed=True, linear=False)
620
  score = _score_annotation(hits, prompt=prompt_text)
 
629
  tag = "complete" if has_eos else "max-length"
630
  html_map, table, ann_status = _annotate(dna)
631
  status = (f"Best of {num_samples}: sample {best_idx+1}, "
632
+ f"score={best_score:.2f}, "
633
  f"{len(dna)} bp ({tag}, {gen_time:.1f}s). {ann_status}")
634
  yield dna, status, html_map, table, ""
635
 
 
757
  0.1, 1.0, value=0.3, step=0.05, label="Temperature",
758
  )
759
  num_samples = gr.Slider(
760
+ 1, 8, value=1, step=1, label="Top N",
761
  )
762
 
763
  max_tokens = gr.Slider(