bhsinghgrid commited on
Commit
a3ec6c4
·
verified ·
1 Parent(s): 5fd6ec8

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +71 -5
app.py CHANGED
@@ -461,13 +461,33 @@ def _live_task_analysis(model_bundle, task: str, input_text: str) -> str:
461
  )
462
 
463
  if str(task) == "2":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  tfidf = _mini_tfidf_scores(text)
465
  top = sorted(tfidf.items(), key=lambda kv: kv[1], reverse=True)[:5]
 
466
  return (
467
  f"[Live Task2]\n"
468
  f"Input: {text}\nPrediction: {pred}\n"
469
  f"Token-length={len(toks)} unique-ratio={uniq:.3f}\n"
470
- f"TF-IDF(top): {top}"
 
471
  )
472
 
473
  if str(task) == "3":
@@ -565,7 +585,7 @@ def start_run_all_background(model_bundle, output_dir, input_text, task4_phase):
565
 
566
  def poll_run_all_background(job_id, output_dir):
567
  if not job_id or job_id not in _BG_JOBS:
568
- msg = "No active background job. Start Run All 5 Tasks first."
569
  empty = refresh_task_outputs(output_dir)
570
  return msg, msg, *empty
571
  j = _BG_JOBS[job_id]
@@ -577,6 +597,12 @@ def poll_run_all_background(job_id, output_dir):
577
  return status, j.get("log", ""), *outputs
578
 
579
 
 
 
 
 
 
 
580
  def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
581
  if not model_bundle:
582
  raise gr.Error("Load a model first.")
@@ -641,11 +667,29 @@ def refresh_task_outputs(output_dir):
641
 
642
  task2_drift = _img_or_none(os.path.join(output_dir, "task2_semantic_drift.png"))
643
  task2_attn = _img_or_none(os.path.join(output_dir, "task2_attn_t0.png"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
644
  task3_space = _img_or_none(os.path.join(output_dir, "task3_concept_space.png"))
645
  task4_plot = _img_or_none(os.path.join(output_dir, "task4_ablation_3d.png"))
646
  if task4_plot is None:
647
  task4_plot = _img_or_none(os.path.join(output_dir, "task4_3d.png"))
648
- return task1_txt, task2_txt, task2_drift, task2_attn, task3_txt, task3_space, task5_txt, task4_plot
 
 
 
649
 
650
 
651
  CUSTOM_CSS = """
@@ -745,6 +789,9 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
745
  with gr.Row():
746
  task2_drift_img = gr.Image(label="Task2 Drift", type="filepath")
747
  task2_attn_img = gr.Image(label="Task2 Attention", type="filepath")
 
 
 
748
  task3_box = gr.Textbox(label="Task 3 Report", lines=10, interactive=False)
749
  task3_img = gr.Image(label="Task3 Concept Space", type="filepath")
750
  task5_box = gr.Textbox(label="Task 5 Report", lines=10, interactive=False)
@@ -839,9 +886,22 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
839
  )
840
 
841
  run_single_btn.click(
842
- fn=run_single_task,
843
  inputs=[model_state, task_choice, analysis_output_dir, analysis_input, task4_phase],
844
- outputs=[task_run_status, task_run_log],
 
 
 
 
 
 
 
 
 
 
 
 
 
845
  )
846
  run_all_btn.click(
847
  fn=start_run_all_background,
@@ -858,6 +918,8 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
858
  task2_box,
859
  task2_drift_img,
860
  task2_attn_img,
 
 
861
  task3_box,
862
  task3_img,
863
  task5_box,
@@ -872,6 +934,8 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
872
  task2_box,
873
  task2_drift_img,
874
  task2_attn_img,
 
 
875
  task3_box,
876
  task3_img,
877
  task5_box,
@@ -890,6 +954,8 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
890
  task2_box,
891
  task2_drift_img,
892
  task2_attn_img,
 
 
893
  task3_box,
894
  task3_img,
895
  task5_box,
 
461
  )
462
 
463
  if str(task) == "2":
464
+ # Live diffusion proxy: run same input with multiple step counts and
465
+ # show semantic drift to final output while task is running.
466
+ base_steps = int(model_bundle["cfg"]["inference"].get("num_steps", 64))
467
+ step_grid = sorted(set([max(1, base_steps), max(1, base_steps // 2), max(1, base_steps // 4), 1]), reverse=True)
468
+ traj = []
469
+ final_out = None
470
+ for s in step_grid:
471
+ out_s = _run_single_prediction(model_bundle, text, {"num_steps": int(s)})
472
+ if s == 1:
473
+ final_out = out_s
474
+ traj.append((s, out_s))
475
+ if final_out is None and traj:
476
+ final_out = traj[-1][1]
477
+ drift_rows = []
478
+ for s, out_s in traj:
479
+ d = _compute_cer(out_s, final_out or out_s)
480
+ drift_rows.append((s, round(d, 4), out_s[:56]))
481
+
482
  tfidf = _mini_tfidf_scores(text)
483
  top = sorted(tfidf.items(), key=lambda kv: kv[1], reverse=True)[:5]
484
+ traj_txt = "\n".join([f"steps={s:>3d} drift_to_final={d:.4f} out={o}" for s, d, o in drift_rows])
485
  return (
486
  f"[Live Task2]\n"
487
  f"Input: {text}\nPrediction: {pred}\n"
488
  f"Token-length={len(toks)} unique-ratio={uniq:.3f}\n"
489
+ f"TF-IDF(top): {top}\n"
490
+ f"Diffusion trajectory (live):\n{traj_txt}"
491
  )
492
 
493
  if str(task) == "3":
 
585
 
586
  def poll_run_all_background(job_id, output_dir):
587
  if not job_id or job_id not in _BG_JOBS:
588
+ msg = "Background job idle. You can run a single task or start Run All 5 in background."
589
  empty = refresh_task_outputs(output_dir)
590
  return msg, msg, *empty
591
  j = _BG_JOBS[job_id]
 
597
  return status, j.get("log", ""), *outputs
598
 
599
 
600
+ def run_single_task_and_refresh(model_bundle, task, output_dir, input_text, task4_phase):
601
+ status, log = run_single_task(model_bundle, task, output_dir, input_text, task4_phase)
602
+ out = refresh_task_outputs(output_dir)
603
+ return status, log, *out
604
+
605
+
606
  def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
607
  if not model_bundle:
608
  raise gr.Error("Load a model first.")
 
667
 
668
  task2_drift = _img_or_none(os.path.join(output_dir, "task2_semantic_drift.png"))
669
  task2_attn = _img_or_none(os.path.join(output_dir, "task2_attn_t0.png"))
670
+ task2_evolution = _img_or_none(os.path.join(output_dir, "task2_attn_evolution.png"))
671
+ # Show farthest diffusion step snapshot if available (t=max).
672
+ task2_tmax = None
673
+ try:
674
+ cands = []
675
+ for name in os.listdir(output_dir):
676
+ if name.startswith("task2_attn_t") and name.endswith(".png"):
677
+ step = name.replace("task2_attn_t", "").replace(".png", "")
678
+ if step.isdigit():
679
+ cands.append((int(step), os.path.join(output_dir, name)))
680
+ if cands:
681
+ cands.sort(key=lambda x: x[0], reverse=True)
682
+ task2_tmax = cands[0][1]
683
+ except Exception:
684
+ task2_tmax = None
685
  task3_space = _img_or_none(os.path.join(output_dir, "task3_concept_space.png"))
686
  task4_plot = _img_or_none(os.path.join(output_dir, "task4_ablation_3d.png"))
687
  if task4_plot is None:
688
  task4_plot = _img_or_none(os.path.join(output_dir, "task4_3d.png"))
689
+ return (
690
+ task1_txt, task2_txt, task2_drift, task2_attn, task2_tmax, task2_evolution,
691
+ task3_txt, task3_space, task5_txt, task4_plot
692
+ )
693
 
694
 
695
  CUSTOM_CSS = """
 
789
  with gr.Row():
790
  task2_drift_img = gr.Image(label="Task2 Drift", type="filepath")
791
  task2_attn_img = gr.Image(label="Task2 Attention", type="filepath")
792
+ with gr.Row():
793
+ task2_tmax_img = gr.Image(label="Task2 Attention (t=max)", type="filepath")
794
+ task2_evolution_img = gr.Image(label="Task2 Evolution", type="filepath")
795
  task3_box = gr.Textbox(label="Task 3 Report", lines=10, interactive=False)
796
  task3_img = gr.Image(label="Task3 Concept Space", type="filepath")
797
  task5_box = gr.Textbox(label="Task 5 Report", lines=10, interactive=False)
 
886
  )
887
 
888
  run_single_btn.click(
889
+ fn=run_single_task_and_refresh,
890
  inputs=[model_state, task_choice, analysis_output_dir, analysis_input, task4_phase],
891
+ outputs=[
892
+ task_run_status,
893
+ task_run_log,
894
+ task1_box,
895
+ task2_box,
896
+ task2_drift_img,
897
+ task2_attn_img,
898
+ task2_tmax_img,
899
+ task2_evolution_img,
900
+ task3_box,
901
+ task3_img,
902
+ task5_box,
903
+ task4_img,
904
+ ],
905
  )
906
  run_all_btn.click(
907
  fn=start_run_all_background,
 
918
  task2_box,
919
  task2_drift_img,
920
  task2_attn_img,
921
+ task2_tmax_img,
922
+ task2_evolution_img,
923
  task3_box,
924
  task3_img,
925
  task5_box,
 
934
  task2_box,
935
  task2_drift_img,
936
  task2_attn_img,
937
+ task2_tmax_img,
938
+ task2_evolution_img,
939
  task3_box,
940
  task3_img,
941
  task5_box,
 
954
  task2_box,
955
  task2_drift_img,
956
  task2_attn_img,
957
+ task2_tmax_img,
958
+ task2_evolution_img,
959
  task3_box,
960
  task3_img,
961
  task5_box,