bhsinghgrid committed on
Commit
483e2dc
·
verified ·
1 Parent(s): 2fdbfb0

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +160 -37
app.py CHANGED
@@ -38,6 +38,20 @@ except Exception:
38
  mlflow = None
39
 
40
  _MLFLOW_READY = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  def _setup_mlflow_once():
@@ -71,6 +85,41 @@ def _mlflow_event(run_name: str, params: dict | None = None, metrics: dict | Non
71
  except Exception:
72
  pass
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  HF_DEFAULT_MODEL_REPO = os.environ.get("HF_DEFAULT_MODEL_REPO", "bhsinghgrid/DevaFlow")
75
  HF_DEFAULT_MODEL_FILE = os.environ.get("HF_DEFAULT_MODEL_FILE", "best_model.pt")
76
 
@@ -384,7 +433,7 @@ def _resolve_analysis_script() -> Path | None:
384
  return None
385
 
386
 
387
- def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze"):
388
  os.makedirs(output_dir, exist_ok=True)
389
  script = _resolve_analysis_script()
390
  if script is None:
@@ -420,6 +469,8 @@ def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati
420
  cmd.extend(["--input", input_text])
421
  if str(task) == "4":
422
  cmd.extend(["--phase", phase])
 
 
423
 
424
  env = os.environ.copy()
425
  env.setdefault("HF_HOME", "/tmp/hf_home")
@@ -502,7 +553,7 @@ def _run_single_prediction(model_bundle, text: str, cfg_override: dict | None =
502
  return _decode_with_cleanup(tgt_tok, out[0].tolist(), text.strip(), cfg["inference"])
503
 
504
 
505
- def _live_task_analysis(model_bundle, task: str, input_text: str) -> str:
506
  text = input_text.strip()
507
  if not text:
508
  return "Live analysis skipped: empty input."
@@ -567,7 +618,7 @@ def _live_task_analysis(model_bundle, task: str, input_text: str) -> str:
567
 
568
  if str(task) == "5":
569
  ref = _iast_to_deva(text)
570
- scales = [0.0, 0.5, 1.0, 1.5, 2.0]
571
  rows = []
572
  for s in scales:
573
  cfg_map = {
@@ -582,13 +633,14 @@ def _live_task_analysis(model_bundle, task: str, input_text: str) -> str:
582
  return _live_input_summary(model_bundle, text)
583
 
584
 
585
- def _bg_worker(job_id: str, model_bundle, output_dir: str, input_text: str, task4_phase: str):
586
  tasks = ["1", "2", "3", "4", "5"]
587
  failures = 0
588
  logs = []
589
  run_start = time.perf_counter()
590
  _BG_JOBS[job_id].update({"state": "running", "progress": 0, "failures": 0, "updated": datetime.now().isoformat()})
591
  for idx, task in enumerate(tasks, start=1):
 
592
  _BG_JOBS[job_id].update(
593
  {
594
  "state": f"running task {task}",
@@ -596,13 +648,24 @@ def _bg_worker(job_id: str, model_bundle, output_dir: str, input_text: str, task
596
  "updated": datetime.now().isoformat(),
597
  }
598
  )
599
- code, log, used_bundled = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
 
 
 
 
 
 
 
600
  logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
601
  if code != 0:
602
  failures += 1
603
- logs.append(f"\n[Live fallback]\n{_live_task_analysis(model_bundle, task, input_text)}\n")
 
604
  elif used_bundled:
605
- logs.append(f"\n[Live bundled summary]\n{_live_task_analysis(model_bundle, task, input_text)}\n")
 
 
 
606
  _BG_JOBS[job_id].update(
607
  {
608
  "log": "".join(logs),
@@ -655,7 +718,7 @@ def _bg_worker(job_id: str, model_bundle, output_dir: str, input_text: str, task
655
  )
656
 
657
 
658
- def start_run_all_background(model_bundle, output_dir, input_text, task4_phase):
659
  if not model_bundle:
660
  raise gr.Error("Load a model first.")
661
  os.makedirs(output_dir, exist_ok=True)
@@ -669,53 +732,64 @@ def start_run_all_background(model_bundle, output_dir, input_text, task4_phase):
669
  "output_dir": output_dir,
670
  "created": datetime.now().isoformat(),
671
  "updated": datetime.now().isoformat(),
 
672
  }
673
  th = threading.Thread(
674
  target=_bg_worker,
675
- args=(job_id, model_bundle, output_dir, input_text, task4_phase),
676
  daemon=True,
677
  )
678
  th.start()
679
- return f"Background run started. Job ID: {job_id}", f"Job {job_id} queued...", job_id
 
680
 
681
 
682
  def poll_run_all_background(job_id, output_dir):
683
  if not job_id or job_id not in _BG_JOBS:
684
  msg = "Background job idle. You can run a single task or start Run All 5 in background."
685
  empty = refresh_task_outputs(output_dir)
686
- return msg, msg, *empty
 
687
  j = _BG_JOBS[job_id]
688
  status = (
689
  f"Job {job_id} | state={j['state']} | progress={j['progress']}% | "
690
  f"failures={j['failures']} | updated={j['updated']}"
691
  )
692
  outputs = refresh_task_outputs(output_dir)
693
- return status, j.get("log", ""), *outputs
 
694
 
695
 
696
- def run_single_task_and_refresh(model_bundle, task, output_dir, input_text, task4_phase):
697
- status, log = run_single_task(model_bundle, task, output_dir, input_text, task4_phase)
698
  out = refresh_task_outputs(output_dir)
699
- return status, log, *out
700
 
701
 
702
- def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
703
  if not model_bundle:
704
  raise gr.Error("Load a model first.")
705
  t0 = time.perf_counter()
706
- code, log, used_bundled = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
 
 
 
 
707
  elapsed = (time.perf_counter() - t0) * 1000.0
708
  if code != 0:
709
  _bundle_task_outputs(model_bundle, output_dir)
710
- log = f"{log}\n\n--- Live task analysis ---\n{_live_task_analysis(model_bundle, task, input_text)}"
711
  status = f"Task {task} fallback mode: bundled reports + live input analysis."
 
712
  else:
713
  if used_bundled:
714
  _bundle_task_outputs(model_bundle, output_dir)
715
- log = f"{log}\n\n--- Live task analysis ---\n{_live_task_analysis(model_bundle, task, input_text)}"
716
  status = f"Task {task} loaded from bundled analysis outputs + live analysis."
 
717
  else:
718
  status = f"Task {task} completed (exit={code})."
 
719
  _mlflow_event(
720
  run_name=f"space_task_{task}",
721
  params={
@@ -731,10 +805,11 @@ def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
731
  },
732
  tags={"source": "hf_space", "mode": "single_task"},
733
  )
734
- return status, log
 
735
 
736
 
737
- def run_all_tasks(model_bundle, output_dir, input_text, task4_phase):
738
  if not model_bundle:
739
  raise gr.Error("Load a model first.")
740
  logs = []
@@ -742,7 +817,7 @@ def run_all_tasks(model_bundle, output_dir, input_text, task4_phase):
742
  used_bundled_any = False
743
  for task in ["1", "2", "3", "4", "5"]:
744
  code, log, used_bundled = _run_analysis_cmd(
745
- task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase
746
  )
747
  logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
748
  used_bundled_any = used_bundled_any or used_bundled
@@ -813,12 +888,16 @@ def _safe_refresh_task_outputs(output_dir):
813
  return (err, err, None, None, None, None, err, None, err, None)
814
 
815
 
816
- def _safe_start_run_all_background(model_bundle, output_dir, input_text, task4_phase, current_job_id):
 
 
817
  try:
818
- status, log, job_id = start_run_all_background(model_bundle, output_dir, input_text, task4_phase)
819
- return status, log, job_id
 
820
  except Exception as e:
821
- return f"Background start failed: {e}", f"Background start failed: {e}", current_job_id
 
822
 
823
 
824
  def _safe_poll_run_all_background(job_id, output_dir):
@@ -827,16 +906,43 @@ def _safe_poll_run_all_background(job_id, output_dir):
827
  except Exception as e:
828
  err = f"Track error: {e}"
829
  out = _safe_refresh_task_outputs(output_dir)
830
- return err, err, *out
831
 
832
 
833
- def _safe_run_single_task_and_refresh(model_bundle, task, output_dir, input_text, task4_phase):
 
 
834
  try:
835
- return run_single_task_and_refresh(model_bundle, task, output_dir, input_text, task4_phase)
 
836
  except Exception as e:
837
  err = f"Task {task} failed: {e}"
838
  out = _safe_refresh_task_outputs(output_dir)
839
- return err, err, *out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
840
 
841
 
842
  CUSTOM_CSS = """
@@ -895,6 +1001,8 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
895
  init_msg = "Select a model and load." if checkpoint_map() else "No checkpoints found in ablation_results/ or results*/."
896
  load_status = gr.Markdown(init_msg)
897
  model_info = gr.JSON(label="Loaded Model Details")
 
 
898
 
899
  with gr.Tabs():
900
  with gr.Tab("1) Task Runner"):
@@ -915,6 +1023,11 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
915
  value="analyze",
916
  label="Task 4 Phase",
917
  )
 
 
 
 
 
918
  run_all_btn = gr.Button("Run All 5 Tasks (Background)", variant="primary")
919
  track_bg_btn = gr.Button("Track Background Run")
920
 
@@ -1004,7 +1117,7 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
1004
  )
1005
 
1006
  generate_btn.click(
1007
- fn=generate_from_ui,
1008
  inputs=[
1009
  model_state,
1010
  input_text,
@@ -1015,10 +1128,10 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
1015
  num_steps,
1016
  clean_output,
1017
  ],
1018
- outputs=[output_text, run_status, run_record],
1019
  )
1020
  input_text.submit(
1021
- fn=generate_from_ui,
1022
  inputs=[
1023
  model_state,
1024
  input_text,
@@ -1029,15 +1142,20 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
1029
  num_steps,
1030
  clean_output,
1031
  ],
1032
- outputs=[output_text, run_status, run_record],
1033
  )
1034
 
1035
  run_single_btn.click(
1036
  fn=_safe_run_single_task_and_refresh,
1037
- inputs=[model_state, task_choice, analysis_output_dir, analysis_input, task4_phase],
 
 
 
1038
  outputs=[
1039
  task_run_status,
1040
  task_run_log,
 
 
1041
  task1_box,
1042
  task2_box,
1043
  task2_drift_img,
@@ -1052,8 +1170,11 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
1052
  )
1053
  run_all_btn.click(
1054
  fn=_safe_start_run_all_background,
1055
- inputs=[model_state, analysis_output_dir, analysis_input, task4_phase, bg_job_state],
1056
- outputs=[task_run_status, task_run_log, bg_job_state],
 
 
 
1057
  )
1058
  track_bg_btn.click(
1059
  fn=_safe_poll_run_all_background,
@@ -1061,6 +1182,8 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
1061
  outputs=[
1062
  task_run_status,
1063
  task_run_log,
 
 
1064
  task1_box,
1065
  task2_box,
1066
  task2_drift_img,
 
38
  mlflow = None
39
 
40
  _MLFLOW_READY = False
41
# Ordered stages of the inference pipeline, rendered as a checklist by
# _build_flow_markdown (steps 1-3 are marked done once a model is loaded,
# steps 1-11 once a full generation pass has completed).
FLOW_STEPS = [
    "Start",
    "Load Model (checkpoint/config/device/eval)",
    "Load Tokenizers",
    "Input (IAST)",
    "Source Tokenization",
    "Encoder (run once)",
    "KV Cache prepared",
    "Initialize x_T (MASK)",
    "Diffusion loop (T→0, with Task2/Task3 hooks)",
    "Final x0",
    "Decode to Devanagari",
    "Evaluation/Tasks (Task4/Task5)",
]
55
 
56
 
57
  def _setup_mlflow_once():
 
85
  except Exception:
86
  pass
87
 
88
+
89
def _build_flow_markdown(model_loaded=False, inference_ready=False, task_states=None):
    """Render the execution-flow checklist (plus optional task status) as Markdown.

    Args:
        model_loaded: when True, pipeline steps 1-3 are shown as completed.
        inference_ready: when True, steps 1-11 are shown as completed (a full
            generation pass; step 12 covers Task4/Task5 work and stays pending).
        task_states: optional mapping of task id ("1".."5") to a state string
            such as "pending", "running", "failed", "done", or "done(bundled)".

    Returns:
        A Markdown string: one line per flow step, then a "Task Status" section
        when task_states is non-empty.
    """
    lines = ["### Execution Flow"]
    for i, step in enumerate(FLOW_STEPS, start=1):
        status = "⬜"
        if model_loaded and i <= 3:
            status = "✅"
        if inference_ready and i <= 11:
            status = "✅"
        lines.append(f"{status} {i}. {step}")
    if task_states:
        lines.append("")
        lines.append("### Task Status")
        for k in ["1", "2", "3", "4", "5"]:
            v = task_states.get(k, "pending")
            # Fix: the workers record "done(bundled)" as a terminal state, but an
            # exact `v == "done"` match rendered it with the pending icon. Match
            # any "done..." prefix so bundled completions show ✅ too.
            if v.startswith("done"):
                icon = "✅"
            elif v.startswith("running"):
                icon = "🔄"
            elif v == "failed":
                icon = "❌"
            else:
                icon = "⬜"
            lines.append(f"{icon} Task {k}: {v}")
    return "\n".join(lines)
106
+
107
+
108
+ def _task5_cfg(lambda_min, lambda_max, lambda_step, task5_samples):
109
+ lo = float(lambda_min)
110
+ hi = float(lambda_max)
111
+ st = max(0.1, float(lambda_step))
112
+ if hi < lo:
113
+ lo, hi = hi, lo
114
+ vals = []
115
+ cur = lo
116
+ while cur <= hi + 1e-9 and len(vals) < 30:
117
+ vals.append(round(cur, 2))
118
+ cur += st
119
+ if not vals:
120
+ vals = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
121
+ return {"scales": vals, "samples": max(5, int(task5_samples))}
122
+
123
  HF_DEFAULT_MODEL_REPO = os.environ.get("HF_DEFAULT_MODEL_REPO", "bhsinghgrid/DevaFlow")
124
  HF_DEFAULT_MODEL_FILE = os.environ.get("HF_DEFAULT_MODEL_FILE", "best_model.pt")
125
 
 
433
  return None
434
 
435
 
436
+ def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze", task5_samples=50):
437
  os.makedirs(output_dir, exist_ok=True)
438
  script = _resolve_analysis_script()
439
  if script is None:
 
469
  cmd.extend(["--input", input_text])
470
  if str(task) == "4":
471
  cmd.extend(["--phase", phase])
472
+ if str(task) == "5":
473
+ cmd.extend(["--task5_samples", str(int(task5_samples))])
474
 
475
  env = os.environ.copy()
476
  env.setdefault("HF_HOME", "/tmp/hf_home")
 
553
  return _decode_with_cleanup(tgt_tok, out[0].tolist(), text.strip(), cfg["inference"])
554
 
555
 
556
+ def _live_task_analysis(model_bundle, task: str, input_text: str, task5_cfg: dict | None = None) -> str:
557
  text = input_text.strip()
558
  if not text:
559
  return "Live analysis skipped: empty input."
 
618
 
619
  if str(task) == "5":
620
  ref = _iast_to_deva(text)
621
+ scales = (task5_cfg or {}).get("scales", [0.0, 0.5, 1.0, 1.5, 2.0])
622
  rows = []
623
  for s in scales:
624
  cfg_map = {
 
633
  return _live_input_summary(model_bundle, text)
634
 
635
 
636
+ def _bg_worker(job_id: str, model_bundle, output_dir: str, input_text: str, task4_phase: str, task5_cfg: dict):
637
  tasks = ["1", "2", "3", "4", "5"]
638
  failures = 0
639
  logs = []
640
  run_start = time.perf_counter()
641
  _BG_JOBS[job_id].update({"state": "running", "progress": 0, "failures": 0, "updated": datetime.now().isoformat()})
642
  for idx, task in enumerate(tasks, start=1):
643
+ _BG_JOBS[job_id]["task_states"][task] = "running"
644
  _BG_JOBS[job_id].update(
645
  {
646
  "state": f"running task {task}",
 
648
  "updated": datetime.now().isoformat(),
649
  }
650
  )
651
+ code, log, used_bundled = _run_analysis_cmd(
652
+ task,
653
+ model_bundle["ckpt_path"],
654
+ output_dir,
655
+ input_text,
656
+ task4_phase,
657
+ task5_cfg.get("samples", 50),
658
+ )
659
  logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
660
  if code != 0:
661
  failures += 1
662
+ _BG_JOBS[job_id]["task_states"][task] = "failed"
663
+ logs.append(f"\n[Live fallback]\n{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}\n")
664
  elif used_bundled:
665
+ _BG_JOBS[job_id]["task_states"][task] = "done(bundled)"
666
+ logs.append(f"\n[Live bundled summary]\n{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}\n")
667
+ else:
668
+ _BG_JOBS[job_id]["task_states"][task] = "done"
669
  _BG_JOBS[job_id].update(
670
  {
671
  "log": "".join(logs),
 
718
  )
719
 
720
 
721
+ def start_run_all_background(model_bundle, output_dir, input_text, task4_phase, task5_cfg):
722
  if not model_bundle:
723
  raise gr.Error("Load a model first.")
724
  os.makedirs(output_dir, exist_ok=True)
 
732
  "output_dir": output_dir,
733
  "created": datetime.now().isoformat(),
734
  "updated": datetime.now().isoformat(),
735
+ "task_states": {k: "pending" for k in ["1", "2", "3", "4", "5"]},
736
  }
737
  th = threading.Thread(
738
  target=_bg_worker,
739
+ args=(job_id, model_bundle, output_dir, input_text, task4_phase, task5_cfg),
740
  daemon=True,
741
  )
742
  th.start()
743
+ flow = _build_flow_markdown(model_loaded=True, inference_ready=False, task_states=_BG_JOBS[job_id]["task_states"])
744
+ return f"Background run started. Job ID: {job_id}", f"Job {job_id} queued...", job_id, _BG_JOBS[job_id]["task_states"], flow
745
 
746
 
747
def poll_run_all_background(job_id, output_dir):
    """Poll a background run-all job.

    Returns (status_text, log_text, task_states, flow_markdown, *task_outputs);
    falls back to an idle message when the job id is unknown.
    """
    job = _BG_JOBS.get(job_id) if job_id else None
    if job is None:
        idle_msg = "Background job idle. You can run a single task or start Run All 5 in background."
        idle_outputs = refresh_task_outputs(output_dir)
        idle_flow = _build_flow_markdown(model_loaded=False, inference_ready=False, task_states={})
        return (idle_msg, idle_msg, {}, idle_flow, *idle_outputs)
    summary = (
        f"Job {job_id} | state={job['state']} | progress={job['progress']}% | "
        f"failures={job['failures']} | updated={job['updated']}"
    )
    refreshed = refresh_task_outputs(output_dir)
    states = job.get("task_states", {})
    flow_md = _build_flow_markdown(model_loaded=True, inference_ready=False, task_states=states)
    return (summary, job.get("log", ""), states, flow_md, *refreshed)
761
 
762
 
763
def run_single_task_and_refresh(model_bundle, task, output_dir, input_text, task4_phase, task5_cfg):
    """Run one analysis task, then refresh the on-disk task outputs for the UI."""
    task_result = run_single_task(model_bundle, task, output_dir, input_text, task4_phase, task5_cfg)
    refreshed = refresh_task_outputs(output_dir)
    status_msg, log_text, states, flow_md = task_result
    return (status_msg, log_text, states, flow_md, *refreshed)
767
 
768
 
769
+ def run_single_task(model_bundle, task, output_dir, input_text, task4_phase, task5_cfg):
770
  if not model_bundle:
771
  raise gr.Error("Load a model first.")
772
  t0 = time.perf_counter()
773
+ code, log, used_bundled = _run_analysis_cmd(
774
+ task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase, task5_cfg.get("samples", 50)
775
+ )
776
+ task_states = {k: "pending" for k in ["1", "2", "3", "4", "5"]}
777
+ task_states[str(task)] = "running"
778
  elapsed = (time.perf_counter() - t0) * 1000.0
779
  if code != 0:
780
  _bundle_task_outputs(model_bundle, output_dir)
781
+ log = f"{log}\n\n--- Live task analysis ---\n{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}"
782
  status = f"Task {task} fallback mode: bundled reports + live input analysis."
783
+ task_states[str(task)] = "failed"
784
  else:
785
  if used_bundled:
786
  _bundle_task_outputs(model_bundle, output_dir)
787
+ log = f"{log}\n\n--- Live task analysis ---\n{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}"
788
  status = f"Task {task} loaded from bundled analysis outputs + live analysis."
789
+ task_states[str(task)] = "done(bundled)"
790
  else:
791
  status = f"Task {task} completed (exit={code})."
792
+ task_states[str(task)] = "done"
793
  _mlflow_event(
794
  run_name=f"space_task_{task}",
795
  params={
 
805
  },
806
  tags={"source": "hf_space", "mode": "single_task"},
807
  )
808
+ flow = _build_flow_markdown(model_loaded=True, inference_ready=False, task_states=task_states)
809
+ return status, log, task_states, flow
810
 
811
 
812
+ def run_all_tasks(model_bundle, output_dir, input_text, task4_phase, task5_cfg):
813
  if not model_bundle:
814
  raise gr.Error("Load a model first.")
815
  logs = []
 
817
  used_bundled_any = False
818
  for task in ["1", "2", "3", "4", "5"]:
819
  code, log, used_bundled = _run_analysis_cmd(
820
+ task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase, task5_cfg.get("samples", 50)
821
  )
822
  logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
823
  used_bundled_any = used_bundled_any or used_bundled
 
888
  return (err, err, None, None, None, None, err, None, err, None)
889
 
890
 
891
def _safe_start_run_all_background(
    model_bundle, output_dir, input_text, task4_phase, current_job_id, lambda_min, lambda_max, lambda_step, task5_samples
):
    """UI-safe wrapper: start the background run-all job, turning any exception
    into status text instead of surfacing a traceback to the Gradio UI."""
    try:
        sweep_cfg = _task5_cfg(lambda_min, lambda_max, lambda_step, task5_samples)
        status_msg, log_text, job_id, states, flow_md = start_run_all_background(
            model_bundle, output_dir, input_text, task4_phase, sweep_cfg
        )
        return status_msg, log_text, job_id, states, flow_md
    except Exception as exc:
        message = f"Background start failed: {exc}"
        fallback_flow = _build_flow_markdown(
            model_loaded=bool(model_bundle), inference_ready=False, task_states={}
        )
        # Keep the previous job id so an in-flight job can still be tracked.
        return message, message, current_job_id, {}, fallback_flow
901
 
902
 
903
  def _safe_poll_run_all_background(job_id, output_dir):
 
906
  except Exception as e:
907
  err = f"Track error: {e}"
908
  out = _safe_refresh_task_outputs(output_dir)
909
+ return err, err, {}, _build_flow_markdown(model_loaded=False, inference_ready=False, task_states={}), *out
910
 
911
 
912
def _safe_run_single_task_and_refresh(
    model_bundle, task, output_dir, input_text, task4_phase, lambda_min, lambda_max, lambda_step, task5_samples
):
    """UI-safe wrapper around run_single_task_and_refresh: converts any
    exception into an error status plus best-effort refreshed outputs."""
    try:
        sweep_cfg = _task5_cfg(lambda_min, lambda_max, lambda_step, task5_samples)
        return run_single_task_and_refresh(model_bundle, task, output_dir, input_text, task4_phase, sweep_cfg)
    except Exception as exc:
        message = f"Task {task} failed: {exc}"
        salvaged = _safe_refresh_task_outputs(output_dir)
        fallback_flow = _build_flow_markdown(
            model_loaded=bool(model_bundle), inference_ready=False, task_states={}
        )
        return (message, message, {}, fallback_flow, *salvaged)
922
+
923
+
924
def _generate_with_flow(
    model_bundle,
    input_text,
    temperature,
    top_k,
    repetition_penalty,
    diversity_penalty,
    num_steps,
    clean_output,
):
    """Run generation via generate_from_ui and append the updated
    execution-flow markdown for the UI's flow checklist."""
    gen_args = (
        model_bundle,
        input_text,
        temperature,
        top_k,
        repetition_penalty,
        diversity_penalty,
        num_steps,
        clean_output,
    )
    generated_text, status_msg, record = generate_from_ui(*gen_args)
    flow_md = _build_flow_markdown(model_loaded=True, inference_ready=True, task_states={})
    return generated_text, status_msg, record, flow_md
946
 
947
 
948
  CUSTOM_CSS = """
 
1001
  init_msg = "Select a model and load." if checkpoint_map() else "No checkpoints found in ablation_results/ or results*/."
1002
  load_status = gr.Markdown(init_msg)
1003
  model_info = gr.JSON(label="Loaded Model Details")
1004
+ flow_box = gr.Markdown(_build_flow_markdown(model_loaded=False, inference_ready=False, task_states={}))
1005
+ task_states_view = gr.JSON(label="Task Execution State (side-by-side)")
1006
 
1007
  with gr.Tabs():
1008
  with gr.Tab("1) Task Runner"):
 
1023
  value="analyze",
1024
  label="Task 4 Phase",
1025
  )
1026
+ gr.Markdown("**Task 5 Controls**")
1027
+ task5_lambda_min = gr.Slider(0.0, 3.0, value=0.0, step=0.1, label="Task5 λ min")
1028
+ task5_lambda_max = gr.Slider(0.0, 3.0, value=3.0, step=0.1, label="Task5 λ max")
1029
+ task5_lambda_step = gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Task5 λ step")
1030
+ task5_samples = gr.Slider(5, 200, value=50, step=5, label="Task5 sweep samples")
1031
  run_all_btn = gr.Button("Run All 5 Tasks (Background)", variant="primary")
1032
  track_bg_btn = gr.Button("Track Background Run")
1033
 
 
1117
  )
1118
 
1119
  generate_btn.click(
1120
+ fn=_generate_with_flow,
1121
  inputs=[
1122
  model_state,
1123
  input_text,
 
1128
  num_steps,
1129
  clean_output,
1130
  ],
1131
+ outputs=[output_text, run_status, run_record, flow_box],
1132
  )
1133
  input_text.submit(
1134
+ fn=_generate_with_flow,
1135
  inputs=[
1136
  model_state,
1137
  input_text,
 
1142
  num_steps,
1143
  clean_output,
1144
  ],
1145
+ outputs=[output_text, run_status, run_record, flow_box],
1146
  )
1147
 
1148
  run_single_btn.click(
1149
  fn=_safe_run_single_task_and_refresh,
1150
+ inputs=[
1151
+ model_state, task_choice, analysis_output_dir, analysis_input, task4_phase,
1152
+ task5_lambda_min, task5_lambda_max, task5_lambda_step, task5_samples
1153
+ ],
1154
  outputs=[
1155
  task_run_status,
1156
  task_run_log,
1157
+ task_states_view,
1158
+ flow_box,
1159
  task1_box,
1160
  task2_box,
1161
  task2_drift_img,
 
1170
  )
1171
  run_all_btn.click(
1172
  fn=_safe_start_run_all_background,
1173
+ inputs=[
1174
+ model_state, analysis_output_dir, analysis_input, task4_phase, bg_job_state,
1175
+ task5_lambda_min, task5_lambda_max, task5_lambda_step, task5_samples
1176
+ ],
1177
+ outputs=[task_run_status, task_run_log, bg_job_state, task_states_view, flow_box],
1178
  )
1179
  track_bg_btn.click(
1180
  fn=_safe_poll_run_all_background,
 
1182
  outputs=[
1183
  task_run_status,
1184
  task_run_log,
1185
+ task_states_view,
1186
+ flow_box,
1187
  task1_box,
1188
  task2_box,
1189
  task2_drift_img,