Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -38,6 +38,20 @@ except Exception:
|
|
| 38 |
mlflow = None
|
| 39 |
|
| 40 |
_MLFLOW_READY = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
def _setup_mlflow_once():
|
|
@@ -71,6 +85,41 @@ def _mlflow_event(run_name: str, params: dict | None = None, metrics: dict | Non
|
|
| 71 |
except Exception:
|
| 72 |
pass
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
HF_DEFAULT_MODEL_REPO = os.environ.get("HF_DEFAULT_MODEL_REPO", "bhsinghgrid/DevaFlow")
|
| 75 |
HF_DEFAULT_MODEL_FILE = os.environ.get("HF_DEFAULT_MODEL_FILE", "best_model.pt")
|
| 76 |
|
|
@@ -384,7 +433,7 @@ def _resolve_analysis_script() -> Path | None:
|
|
| 384 |
return None
|
| 385 |
|
| 386 |
|
| 387 |
-
def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze"):
|
| 388 |
os.makedirs(output_dir, exist_ok=True)
|
| 389 |
script = _resolve_analysis_script()
|
| 390 |
if script is None:
|
|
@@ -420,6 +469,8 @@ def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati
|
|
| 420 |
cmd.extend(["--input", input_text])
|
| 421 |
if str(task) == "4":
|
| 422 |
cmd.extend(["--phase", phase])
|
|
|
|
|
|
|
| 423 |
|
| 424 |
env = os.environ.copy()
|
| 425 |
env.setdefault("HF_HOME", "/tmp/hf_home")
|
|
@@ -502,7 +553,7 @@ def _run_single_prediction(model_bundle, text: str, cfg_override: dict | None =
|
|
| 502 |
return _decode_with_cleanup(tgt_tok, out[0].tolist(), text.strip(), cfg["inference"])
|
| 503 |
|
| 504 |
|
| 505 |
-
def _live_task_analysis(model_bundle, task: str, input_text: str) -> str:
|
| 506 |
text = input_text.strip()
|
| 507 |
if not text:
|
| 508 |
return "Live analysis skipped: empty input."
|
|
@@ -567,7 +618,7 @@ def _live_task_analysis(model_bundle, task: str, input_text: str) -> str:
|
|
| 567 |
|
| 568 |
if str(task) == "5":
|
| 569 |
ref = _iast_to_deva(text)
|
| 570 |
-
scales = [0.0, 0.5, 1.0, 1.5, 2.0]
|
| 571 |
rows = []
|
| 572 |
for s in scales:
|
| 573 |
cfg_map = {
|
|
@@ -582,13 +633,14 @@ def _live_task_analysis(model_bundle, task: str, input_text: str) -> str:
|
|
| 582 |
return _live_input_summary(model_bundle, text)
|
| 583 |
|
| 584 |
|
| 585 |
-
def _bg_worker(job_id: str, model_bundle, output_dir: str, input_text: str, task4_phase: str):
|
| 586 |
tasks = ["1", "2", "3", "4", "5"]
|
| 587 |
failures = 0
|
| 588 |
logs = []
|
| 589 |
run_start = time.perf_counter()
|
| 590 |
_BG_JOBS[job_id].update({"state": "running", "progress": 0, "failures": 0, "updated": datetime.now().isoformat()})
|
| 591 |
for idx, task in enumerate(tasks, start=1):
|
|
|
|
| 592 |
_BG_JOBS[job_id].update(
|
| 593 |
{
|
| 594 |
"state": f"running task {task}",
|
|
@@ -596,13 +648,24 @@ def _bg_worker(job_id: str, model_bundle, output_dir: str, input_text: str, task
|
|
| 596 |
"updated": datetime.now().isoformat(),
|
| 597 |
}
|
| 598 |
)
|
| 599 |
-
code, log, used_bundled = _run_analysis_cmd(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
|
| 601 |
if code != 0:
|
| 602 |
failures += 1
|
| 603 |
-
|
|
|
|
| 604 |
elif used_bundled:
|
| 605 |
-
|
|
|
|
|
|
|
|
|
|
| 606 |
_BG_JOBS[job_id].update(
|
| 607 |
{
|
| 608 |
"log": "".join(logs),
|
|
@@ -655,7 +718,7 @@ def _bg_worker(job_id: str, model_bundle, output_dir: str, input_text: str, task
|
|
| 655 |
)
|
| 656 |
|
| 657 |
|
| 658 |
-
def start_run_all_background(model_bundle, output_dir, input_text, task4_phase):
|
| 659 |
if not model_bundle:
|
| 660 |
raise gr.Error("Load a model first.")
|
| 661 |
os.makedirs(output_dir, exist_ok=True)
|
|
@@ -669,53 +732,64 @@ def start_run_all_background(model_bundle, output_dir, input_text, task4_phase):
|
|
| 669 |
"output_dir": output_dir,
|
| 670 |
"created": datetime.now().isoformat(),
|
| 671 |
"updated": datetime.now().isoformat(),
|
|
|
|
| 672 |
}
|
| 673 |
th = threading.Thread(
|
| 674 |
target=_bg_worker,
|
| 675 |
-
args=(job_id, model_bundle, output_dir, input_text, task4_phase),
|
| 676 |
daemon=True,
|
| 677 |
)
|
| 678 |
th.start()
|
| 679 |
-
|
|
|
|
| 680 |
|
| 681 |
|
| 682 |
def poll_run_all_background(job_id, output_dir):
|
| 683 |
if not job_id or job_id not in _BG_JOBS:
|
| 684 |
msg = "Background job idle. You can run a single task or start Run All 5 in background."
|
| 685 |
empty = refresh_task_outputs(output_dir)
|
| 686 |
-
|
|
|
|
| 687 |
j = _BG_JOBS[job_id]
|
| 688 |
status = (
|
| 689 |
f"Job {job_id} | state={j['state']} | progress={j['progress']}% | "
|
| 690 |
f"failures={j['failures']} | updated={j['updated']}"
|
| 691 |
)
|
| 692 |
outputs = refresh_task_outputs(output_dir)
|
| 693 |
-
|
|
|
|
| 694 |
|
| 695 |
|
| 696 |
-
def run_single_task_and_refresh(model_bundle, task, output_dir, input_text, task4_phase):
|
| 697 |
-
status, log = run_single_task(model_bundle, task, output_dir, input_text, task4_phase)
|
| 698 |
out = refresh_task_outputs(output_dir)
|
| 699 |
-
return status, log, *out
|
| 700 |
|
| 701 |
|
| 702 |
-
def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
|
| 703 |
if not model_bundle:
|
| 704 |
raise gr.Error("Load a model first.")
|
| 705 |
t0 = time.perf_counter()
|
| 706 |
-
code, log, used_bundled = _run_analysis_cmd(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 707 |
elapsed = (time.perf_counter() - t0) * 1000.0
|
| 708 |
if code != 0:
|
| 709 |
_bundle_task_outputs(model_bundle, output_dir)
|
| 710 |
-
log = f"{log}\n\n--- Live task analysis ---\n{_live_task_analysis(model_bundle, task, input_text)}"
|
| 711 |
status = f"Task {task} fallback mode: bundled reports + live input analysis."
|
|
|
|
| 712 |
else:
|
| 713 |
if used_bundled:
|
| 714 |
_bundle_task_outputs(model_bundle, output_dir)
|
| 715 |
-
log = f"{log}\n\n--- Live task analysis ---\n{_live_task_analysis(model_bundle, task, input_text)}"
|
| 716 |
status = f"Task {task} loaded from bundled analysis outputs + live analysis."
|
|
|
|
| 717 |
else:
|
| 718 |
status = f"Task {task} completed (exit={code})."
|
|
|
|
| 719 |
_mlflow_event(
|
| 720 |
run_name=f"space_task_{task}",
|
| 721 |
params={
|
|
@@ -731,10 +805,11 @@ def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
|
|
| 731 |
},
|
| 732 |
tags={"source": "hf_space", "mode": "single_task"},
|
| 733 |
)
|
| 734 |
-
|
|
|
|
| 735 |
|
| 736 |
|
| 737 |
-
def run_all_tasks(model_bundle, output_dir, input_text, task4_phase):
|
| 738 |
if not model_bundle:
|
| 739 |
raise gr.Error("Load a model first.")
|
| 740 |
logs = []
|
|
@@ -742,7 +817,7 @@ def run_all_tasks(model_bundle, output_dir, input_text, task4_phase):
|
|
| 742 |
used_bundled_any = False
|
| 743 |
for task in ["1", "2", "3", "4", "5"]:
|
| 744 |
code, log, used_bundled = _run_analysis_cmd(
|
| 745 |
-
task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase
|
| 746 |
)
|
| 747 |
logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
|
| 748 |
used_bundled_any = used_bundled_any or used_bundled
|
|
@@ -813,12 +888,16 @@ def _safe_refresh_task_outputs(output_dir):
|
|
| 813 |
return (err, err, None, None, None, None, err, None, err, None)
|
| 814 |
|
| 815 |
|
| 816 |
-
def _safe_start_run_all_background(
|
|
|
|
|
|
|
| 817 |
try:
|
| 818 |
-
|
| 819 |
-
|
|
|
|
| 820 |
except Exception as e:
|
| 821 |
-
|
|
|
|
| 822 |
|
| 823 |
|
| 824 |
def _safe_poll_run_all_background(job_id, output_dir):
|
|
@@ -827,16 +906,43 @@ def _safe_poll_run_all_background(job_id, output_dir):
|
|
| 827 |
except Exception as e:
|
| 828 |
err = f"Track error: {e}"
|
| 829 |
out = _safe_refresh_task_outputs(output_dir)
|
| 830 |
-
return err, err, *out
|
| 831 |
|
| 832 |
|
| 833 |
-
def _safe_run_single_task_and_refresh(
|
|
|
|
|
|
|
| 834 |
try:
|
| 835 |
-
|
|
|
|
| 836 |
except Exception as e:
|
| 837 |
err = f"Task {task} failed: {e}"
|
| 838 |
out = _safe_refresh_task_outputs(output_dir)
|
| 839 |
-
return err, err, *out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 840 |
|
| 841 |
|
| 842 |
CUSTOM_CSS = """
|
|
@@ -895,6 +1001,8 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
|
| 895 |
init_msg = "Select a model and load." if checkpoint_map() else "No checkpoints found in ablation_results/ or results*/."
|
| 896 |
load_status = gr.Markdown(init_msg)
|
| 897 |
model_info = gr.JSON(label="Loaded Model Details")
|
|
|
|
|
|
|
| 898 |
|
| 899 |
with gr.Tabs():
|
| 900 |
with gr.Tab("1) Task Runner"):
|
|
@@ -915,6 +1023,11 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
|
| 915 |
value="analyze",
|
| 916 |
label="Task 4 Phase",
|
| 917 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 918 |
run_all_btn = gr.Button("Run All 5 Tasks (Background)", variant="primary")
|
| 919 |
track_bg_btn = gr.Button("Track Background Run")
|
| 920 |
|
|
@@ -1004,7 +1117,7 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
|
| 1004 |
)
|
| 1005 |
|
| 1006 |
generate_btn.click(
|
| 1007 |
-
fn=
|
| 1008 |
inputs=[
|
| 1009 |
model_state,
|
| 1010 |
input_text,
|
|
@@ -1015,10 +1128,10 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
|
| 1015 |
num_steps,
|
| 1016 |
clean_output,
|
| 1017 |
],
|
| 1018 |
-
outputs=[output_text, run_status, run_record],
|
| 1019 |
)
|
| 1020 |
input_text.submit(
|
| 1021 |
-
fn=
|
| 1022 |
inputs=[
|
| 1023 |
model_state,
|
| 1024 |
input_text,
|
|
@@ -1029,15 +1142,20 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
|
| 1029 |
num_steps,
|
| 1030 |
clean_output,
|
| 1031 |
],
|
| 1032 |
-
outputs=[output_text, run_status, run_record],
|
| 1033 |
)
|
| 1034 |
|
| 1035 |
run_single_btn.click(
|
| 1036 |
fn=_safe_run_single_task_and_refresh,
|
| 1037 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
| 1038 |
outputs=[
|
| 1039 |
task_run_status,
|
| 1040 |
task_run_log,
|
|
|
|
|
|
|
| 1041 |
task1_box,
|
| 1042 |
task2_box,
|
| 1043 |
task2_drift_img,
|
|
@@ -1052,8 +1170,11 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
|
| 1052 |
)
|
| 1053 |
run_all_btn.click(
|
| 1054 |
fn=_safe_start_run_all_background,
|
| 1055 |
-
inputs=[
|
| 1056 |
-
|
|
|
|
|
|
|
|
|
|
| 1057 |
)
|
| 1058 |
track_bg_btn.click(
|
| 1059 |
fn=_safe_poll_run_all_background,
|
|
@@ -1061,6 +1182,8 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
|
| 1061 |
outputs=[
|
| 1062 |
task_run_status,
|
| 1063 |
task_run_log,
|
|
|
|
|
|
|
| 1064 |
task1_box,
|
| 1065 |
task2_box,
|
| 1066 |
task2_drift_img,
|
|
|
|
| 38 |
mlflow = None
|
| 39 |
|
| 40 |
_MLFLOW_READY = False
|
| 41 |
+
FLOW_STEPS = [
|
| 42 |
+
"Start",
|
| 43 |
+
"Load Model (checkpoint/config/device/eval)",
|
| 44 |
+
"Load Tokenizers",
|
| 45 |
+
"Input (IAST)",
|
| 46 |
+
"Source Tokenization",
|
| 47 |
+
"Encoder (run once)",
|
| 48 |
+
"KV Cache prepared",
|
| 49 |
+
"Initialize x_T (MASK)",
|
| 50 |
+
"Diffusion loop (T→0, with Task2/Task3 hooks)",
|
| 51 |
+
"Final x0",
|
| 52 |
+
"Decode to Devanagari",
|
| 53 |
+
"Evaluation/Tasks (Task4/Task5)",
|
| 54 |
+
]
|
| 55 |
|
| 56 |
|
| 57 |
def _setup_mlflow_once():
|
|
|
|
| 85 |
except Exception:
|
| 86 |
pass
|
| 87 |
|
| 88 |
+
|
| 89 |
+
def _build_flow_markdown(model_loaded=False, inference_ready=False, task_states=None):
|
| 90 |
+
lines = ["### Execution Flow"]
|
| 91 |
+
for i, step in enumerate(FLOW_STEPS, start=1):
|
| 92 |
+
status = "⬜"
|
| 93 |
+
if model_loaded and i <= 3:
|
| 94 |
+
status = "✅"
|
| 95 |
+
if inference_ready and i <= 11:
|
| 96 |
+
status = "✅"
|
| 97 |
+
lines.append(f"{status} {i}. {step}")
|
| 98 |
+
if task_states:
|
| 99 |
+
lines.append("")
|
| 100 |
+
lines.append("### Task Status")
|
| 101 |
+
for k in ["1", "2", "3", "4", "5"]:
|
| 102 |
+
v = task_states.get(k, "pending")
|
| 103 |
+
icon = "✅" if v == "done" else ("🔄" if v.startswith("running") else ("❌" if v == "failed" else "⬜"))
|
| 104 |
+
lines.append(f"{icon} Task {k}: {v}")
|
| 105 |
+
return "\n".join(lines)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _task5_cfg(lambda_min, lambda_max, lambda_step, task5_samples):
|
| 109 |
+
lo = float(lambda_min)
|
| 110 |
+
hi = float(lambda_max)
|
| 111 |
+
st = max(0.1, float(lambda_step))
|
| 112 |
+
if hi < lo:
|
| 113 |
+
lo, hi = hi, lo
|
| 114 |
+
vals = []
|
| 115 |
+
cur = lo
|
| 116 |
+
while cur <= hi + 1e-9 and len(vals) < 30:
|
| 117 |
+
vals.append(round(cur, 2))
|
| 118 |
+
cur += st
|
| 119 |
+
if not vals:
|
| 120 |
+
vals = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
|
| 121 |
+
return {"scales": vals, "samples": max(5, int(task5_samples))}
|
| 122 |
+
|
| 123 |
HF_DEFAULT_MODEL_REPO = os.environ.get("HF_DEFAULT_MODEL_REPO", "bhsinghgrid/DevaFlow")
|
| 124 |
HF_DEFAULT_MODEL_FILE = os.environ.get("HF_DEFAULT_MODEL_FILE", "best_model.pt")
|
| 125 |
|
|
|
|
| 433 |
return None
|
| 434 |
|
| 435 |
|
| 436 |
+
def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze", task5_samples=50):
|
| 437 |
os.makedirs(output_dir, exist_ok=True)
|
| 438 |
script = _resolve_analysis_script()
|
| 439 |
if script is None:
|
|
|
|
| 469 |
cmd.extend(["--input", input_text])
|
| 470 |
if str(task) == "4":
|
| 471 |
cmd.extend(["--phase", phase])
|
| 472 |
+
if str(task) == "5":
|
| 473 |
+
cmd.extend(["--task5_samples", str(int(task5_samples))])
|
| 474 |
|
| 475 |
env = os.environ.copy()
|
| 476 |
env.setdefault("HF_HOME", "/tmp/hf_home")
|
|
|
|
| 553 |
return _decode_with_cleanup(tgt_tok, out[0].tolist(), text.strip(), cfg["inference"])
|
| 554 |
|
| 555 |
|
| 556 |
+
def _live_task_analysis(model_bundle, task: str, input_text: str, task5_cfg: dict | None = None) -> str:
|
| 557 |
text = input_text.strip()
|
| 558 |
if not text:
|
| 559 |
return "Live analysis skipped: empty input."
|
|
|
|
| 618 |
|
| 619 |
if str(task) == "5":
|
| 620 |
ref = _iast_to_deva(text)
|
| 621 |
+
scales = (task5_cfg or {}).get("scales", [0.0, 0.5, 1.0, 1.5, 2.0])
|
| 622 |
rows = []
|
| 623 |
for s in scales:
|
| 624 |
cfg_map = {
|
|
|
|
| 633 |
return _live_input_summary(model_bundle, text)
|
| 634 |
|
| 635 |
|
| 636 |
+
def _bg_worker(job_id: str, model_bundle, output_dir: str, input_text: str, task4_phase: str, task5_cfg: dict):
|
| 637 |
tasks = ["1", "2", "3", "4", "5"]
|
| 638 |
failures = 0
|
| 639 |
logs = []
|
| 640 |
run_start = time.perf_counter()
|
| 641 |
_BG_JOBS[job_id].update({"state": "running", "progress": 0, "failures": 0, "updated": datetime.now().isoformat()})
|
| 642 |
for idx, task in enumerate(tasks, start=1):
|
| 643 |
+
_BG_JOBS[job_id]["task_states"][task] = "running"
|
| 644 |
_BG_JOBS[job_id].update(
|
| 645 |
{
|
| 646 |
"state": f"running task {task}",
|
|
|
|
| 648 |
"updated": datetime.now().isoformat(),
|
| 649 |
}
|
| 650 |
)
|
| 651 |
+
code, log, used_bundled = _run_analysis_cmd(
|
| 652 |
+
task,
|
| 653 |
+
model_bundle["ckpt_path"],
|
| 654 |
+
output_dir,
|
| 655 |
+
input_text,
|
| 656 |
+
task4_phase,
|
| 657 |
+
task5_cfg.get("samples", 50),
|
| 658 |
+
)
|
| 659 |
logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
|
| 660 |
if code != 0:
|
| 661 |
failures += 1
|
| 662 |
+
_BG_JOBS[job_id]["task_states"][task] = "failed"
|
| 663 |
+
logs.append(f"\n[Live fallback]\n{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}\n")
|
| 664 |
elif used_bundled:
|
| 665 |
+
_BG_JOBS[job_id]["task_states"][task] = "done(bundled)"
|
| 666 |
+
logs.append(f"\n[Live bundled summary]\n{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}\n")
|
| 667 |
+
else:
|
| 668 |
+
_BG_JOBS[job_id]["task_states"][task] = "done"
|
| 669 |
_BG_JOBS[job_id].update(
|
| 670 |
{
|
| 671 |
"log": "".join(logs),
|
|
|
|
| 718 |
)
|
| 719 |
|
| 720 |
|
| 721 |
+
def start_run_all_background(model_bundle, output_dir, input_text, task4_phase, task5_cfg):
|
| 722 |
if not model_bundle:
|
| 723 |
raise gr.Error("Load a model first.")
|
| 724 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
| 732 |
"output_dir": output_dir,
|
| 733 |
"created": datetime.now().isoformat(),
|
| 734 |
"updated": datetime.now().isoformat(),
|
| 735 |
+
"task_states": {k: "pending" for k in ["1", "2", "3", "4", "5"]},
|
| 736 |
}
|
| 737 |
th = threading.Thread(
|
| 738 |
target=_bg_worker,
|
| 739 |
+
args=(job_id, model_bundle, output_dir, input_text, task4_phase, task5_cfg),
|
| 740 |
daemon=True,
|
| 741 |
)
|
| 742 |
th.start()
|
| 743 |
+
flow = _build_flow_markdown(model_loaded=True, inference_ready=False, task_states=_BG_JOBS[job_id]["task_states"])
|
| 744 |
+
return f"Background run started. Job ID: {job_id}", f"Job {job_id} queued...", job_id, _BG_JOBS[job_id]["task_states"], flow
|
| 745 |
|
| 746 |
|
| 747 |
def poll_run_all_background(job_id, output_dir):
|
| 748 |
if not job_id or job_id not in _BG_JOBS:
|
| 749 |
msg = "Background job idle. You can run a single task or start Run All 5 in background."
|
| 750 |
empty = refresh_task_outputs(output_dir)
|
| 751 |
+
flow = _build_flow_markdown(model_loaded=False, inference_ready=False, task_states={})
|
| 752 |
+
return msg, msg, {}, flow, *empty
|
| 753 |
j = _BG_JOBS[job_id]
|
| 754 |
status = (
|
| 755 |
f"Job {job_id} | state={j['state']} | progress={j['progress']}% | "
|
| 756 |
f"failures={j['failures']} | updated={j['updated']}"
|
| 757 |
)
|
| 758 |
outputs = refresh_task_outputs(output_dir)
|
| 759 |
+
flow = _build_flow_markdown(model_loaded=True, inference_ready=False, task_states=j.get("task_states", {}))
|
| 760 |
+
return status, j.get("log", ""), j.get("task_states", {}), flow, *outputs
|
| 761 |
|
| 762 |
|
| 763 |
+
def run_single_task_and_refresh(model_bundle, task, output_dir, input_text, task4_phase, task5_cfg):
|
| 764 |
+
status, log, task_states, flow = run_single_task(model_bundle, task, output_dir, input_text, task4_phase, task5_cfg)
|
| 765 |
out = refresh_task_outputs(output_dir)
|
| 766 |
+
return status, log, task_states, flow, *out
|
| 767 |
|
| 768 |
|
| 769 |
+
def run_single_task(model_bundle, task, output_dir, input_text, task4_phase, task5_cfg):
|
| 770 |
if not model_bundle:
|
| 771 |
raise gr.Error("Load a model first.")
|
| 772 |
t0 = time.perf_counter()
|
| 773 |
+
code, log, used_bundled = _run_analysis_cmd(
|
| 774 |
+
task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase, task5_cfg.get("samples", 50)
|
| 775 |
+
)
|
| 776 |
+
task_states = {k: "pending" for k in ["1", "2", "3", "4", "5"]}
|
| 777 |
+
task_states[str(task)] = "running"
|
| 778 |
elapsed = (time.perf_counter() - t0) * 1000.0
|
| 779 |
if code != 0:
|
| 780 |
_bundle_task_outputs(model_bundle, output_dir)
|
| 781 |
+
log = f"{log}\n\n--- Live task analysis ---\n{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}"
|
| 782 |
status = f"Task {task} fallback mode: bundled reports + live input analysis."
|
| 783 |
+
task_states[str(task)] = "failed"
|
| 784 |
else:
|
| 785 |
if used_bundled:
|
| 786 |
_bundle_task_outputs(model_bundle, output_dir)
|
| 787 |
+
log = f"{log}\n\n--- Live task analysis ---\n{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}"
|
| 788 |
status = f"Task {task} loaded from bundled analysis outputs + live analysis."
|
| 789 |
+
task_states[str(task)] = "done(bundled)"
|
| 790 |
else:
|
| 791 |
status = f"Task {task} completed (exit={code})."
|
| 792 |
+
task_states[str(task)] = "done"
|
| 793 |
_mlflow_event(
|
| 794 |
run_name=f"space_task_{task}",
|
| 795 |
params={
|
|
|
|
| 805 |
},
|
| 806 |
tags={"source": "hf_space", "mode": "single_task"},
|
| 807 |
)
|
| 808 |
+
flow = _build_flow_markdown(model_loaded=True, inference_ready=False, task_states=task_states)
|
| 809 |
+
return status, log, task_states, flow
|
| 810 |
|
| 811 |
|
| 812 |
+
def run_all_tasks(model_bundle, output_dir, input_text, task4_phase, task5_cfg):
|
| 813 |
if not model_bundle:
|
| 814 |
raise gr.Error("Load a model first.")
|
| 815 |
logs = []
|
|
|
|
| 817 |
used_bundled_any = False
|
| 818 |
for task in ["1", "2", "3", "4", "5"]:
|
| 819 |
code, log, used_bundled = _run_analysis_cmd(
|
| 820 |
+
task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase, task5_cfg.get("samples", 50)
|
| 821 |
)
|
| 822 |
logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
|
| 823 |
used_bundled_any = used_bundled_any or used_bundled
|
|
|
|
| 888 |
return (err, err, None, None, None, None, err, None, err, None)
|
| 889 |
|
| 890 |
|
| 891 |
+
def _safe_start_run_all_background(
|
| 892 |
+
model_bundle, output_dir, input_text, task4_phase, current_job_id, lambda_min, lambda_max, lambda_step, task5_samples
|
| 893 |
+
):
|
| 894 |
try:
|
| 895 |
+
cfg = _task5_cfg(lambda_min, lambda_max, lambda_step, task5_samples)
|
| 896 |
+
status, log, job_id, task_states, flow = start_run_all_background(model_bundle, output_dir, input_text, task4_phase, cfg)
|
| 897 |
+
return status, log, job_id, task_states, flow
|
| 898 |
except Exception as e:
|
| 899 |
+
err = f"Background start failed: {e}"
|
| 900 |
+
return err, err, current_job_id, {}, _build_flow_markdown(model_loaded=bool(model_bundle), inference_ready=False, task_states={})
|
| 901 |
|
| 902 |
|
| 903 |
def _safe_poll_run_all_background(job_id, output_dir):
|
|
|
|
| 906 |
except Exception as e:
|
| 907 |
err = f"Track error: {e}"
|
| 908 |
out = _safe_refresh_task_outputs(output_dir)
|
| 909 |
+
return err, err, {}, _build_flow_markdown(model_loaded=False, inference_ready=False, task_states={}), *out
|
| 910 |
|
| 911 |
|
| 912 |
+
def _safe_run_single_task_and_refresh(
|
| 913 |
+
model_bundle, task, output_dir, input_text, task4_phase, lambda_min, lambda_max, lambda_step, task5_samples
|
| 914 |
+
):
|
| 915 |
try:
|
| 916 |
+
cfg = _task5_cfg(lambda_min, lambda_max, lambda_step, task5_samples)
|
| 917 |
+
return run_single_task_and_refresh(model_bundle, task, output_dir, input_text, task4_phase, cfg)
|
| 918 |
except Exception as e:
|
| 919 |
err = f"Task {task} failed: {e}"
|
| 920 |
out = _safe_refresh_task_outputs(output_dir)
|
| 921 |
+
return err, err, {}, _build_flow_markdown(model_loaded=bool(model_bundle), inference_ready=False, task_states={}), *out
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
def _generate_with_flow(
|
| 925 |
+
model_bundle,
|
| 926 |
+
input_text,
|
| 927 |
+
temperature,
|
| 928 |
+
top_k,
|
| 929 |
+
repetition_penalty,
|
| 930 |
+
diversity_penalty,
|
| 931 |
+
num_steps,
|
| 932 |
+
clean_output,
|
| 933 |
+
):
|
| 934 |
+
out_text, status, meta = generate_from_ui(
|
| 935 |
+
model_bundle,
|
| 936 |
+
input_text,
|
| 937 |
+
temperature,
|
| 938 |
+
top_k,
|
| 939 |
+
repetition_penalty,
|
| 940 |
+
diversity_penalty,
|
| 941 |
+
num_steps,
|
| 942 |
+
clean_output,
|
| 943 |
+
)
|
| 944 |
+
flow = _build_flow_markdown(model_loaded=True, inference_ready=True, task_states={})
|
| 945 |
+
return out_text, status, meta, flow
|
| 946 |
|
| 947 |
|
| 948 |
CUSTOM_CSS = """
|
|
|
|
| 1001 |
init_msg = "Select a model and load." if checkpoint_map() else "No checkpoints found in ablation_results/ or results*/."
|
| 1002 |
load_status = gr.Markdown(init_msg)
|
| 1003 |
model_info = gr.JSON(label="Loaded Model Details")
|
| 1004 |
+
flow_box = gr.Markdown(_build_flow_markdown(model_loaded=False, inference_ready=False, task_states={}))
|
| 1005 |
+
task_states_view = gr.JSON(label="Task Execution State (side-by-side)")
|
| 1006 |
|
| 1007 |
with gr.Tabs():
|
| 1008 |
with gr.Tab("1) Task Runner"):
|
|
|
|
| 1023 |
value="analyze",
|
| 1024 |
label="Task 4 Phase",
|
| 1025 |
)
|
| 1026 |
+
gr.Markdown("**Task 5 Controls**")
|
| 1027 |
+
task5_lambda_min = gr.Slider(0.0, 3.0, value=0.0, step=0.1, label="Task5 λ min")
|
| 1028 |
+
task5_lambda_max = gr.Slider(0.0, 3.0, value=3.0, step=0.1, label="Task5 λ max")
|
| 1029 |
+
task5_lambda_step = gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Task5 λ step")
|
| 1030 |
+
task5_samples = gr.Slider(5, 200, value=50, step=5, label="Task5 sweep samples")
|
| 1031 |
run_all_btn = gr.Button("Run All 5 Tasks (Background)", variant="primary")
|
| 1032 |
track_bg_btn = gr.Button("Track Background Run")
|
| 1033 |
|
|
|
|
| 1117 |
)
|
| 1118 |
|
| 1119 |
generate_btn.click(
|
| 1120 |
+
fn=_generate_with_flow,
|
| 1121 |
inputs=[
|
| 1122 |
model_state,
|
| 1123 |
input_text,
|
|
|
|
| 1128 |
num_steps,
|
| 1129 |
clean_output,
|
| 1130 |
],
|
| 1131 |
+
outputs=[output_text, run_status, run_record, flow_box],
|
| 1132 |
)
|
| 1133 |
input_text.submit(
|
| 1134 |
+
fn=_generate_with_flow,
|
| 1135 |
inputs=[
|
| 1136 |
model_state,
|
| 1137 |
input_text,
|
|
|
|
| 1142 |
num_steps,
|
| 1143 |
clean_output,
|
| 1144 |
],
|
| 1145 |
+
outputs=[output_text, run_status, run_record, flow_box],
|
| 1146 |
)
|
| 1147 |
|
| 1148 |
run_single_btn.click(
|
| 1149 |
fn=_safe_run_single_task_and_refresh,
|
| 1150 |
+
inputs=[
|
| 1151 |
+
model_state, task_choice, analysis_output_dir, analysis_input, task4_phase,
|
| 1152 |
+
task5_lambda_min, task5_lambda_max, task5_lambda_step, task5_samples
|
| 1153 |
+
],
|
| 1154 |
outputs=[
|
| 1155 |
task_run_status,
|
| 1156 |
task_run_log,
|
| 1157 |
+
task_states_view,
|
| 1158 |
+
flow_box,
|
| 1159 |
task1_box,
|
| 1160 |
task2_box,
|
| 1161 |
task2_drift_img,
|
|
|
|
| 1170 |
)
|
| 1171 |
run_all_btn.click(
|
| 1172 |
fn=_safe_start_run_all_background,
|
| 1173 |
+
inputs=[
|
| 1174 |
+
model_state, analysis_output_dir, analysis_input, task4_phase, bg_job_state,
|
| 1175 |
+
task5_lambda_min, task5_lambda_max, task5_lambda_step, task5_samples
|
| 1176 |
+
],
|
| 1177 |
+
outputs=[task_run_status, task_run_log, bg_job_state, task_states_view, flow_box],
|
| 1178 |
)
|
| 1179 |
track_bg_btn.click(
|
| 1180 |
fn=_safe_poll_run_all_background,
|
|
|
|
| 1182 |
outputs=[
|
| 1183 |
task_run_status,
|
| 1184 |
task_run_log,
|
| 1185 |
+
task_states_view,
|
| 1186 |
+
flow_box,
|
| 1187 |
task1_box,
|
| 1188 |
task2_box,
|
| 1189 |
task2_drift_img,
|