Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -4,6 +4,8 @@ import os
|
|
| 4 |
import subprocess
|
| 5 |
import sys
|
| 6 |
import shutil
|
|
|
|
|
|
|
| 7 |
from datetime import datetime
|
| 8 |
from pathlib import Path
|
| 9 |
|
|
@@ -12,13 +14,22 @@ import torch
|
|
| 12 |
from huggingface_hub import hf_hub_download
|
| 13 |
|
| 14 |
from config import CONFIG
|
| 15 |
-
from inference import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
|
| 17 |
|
| 18 |
|
| 19 |
RESULTS_DIR = "generated_results"
|
| 20 |
DEFAULT_ANALYSIS_OUT = "analysis_outputs/T4"
|
| 21 |
os.makedirs(RESULTS_DIR, exist_ok=True)
|
|
|
|
| 22 |
|
| 23 |
HF_DEFAULT_MODEL_REPO = os.environ.get("HF_DEFAULT_MODEL_REPO", "bhsinghgrid/DevaFlow")
|
| 24 |
HF_DEFAULT_MODEL_FILE = os.environ.get("HF_DEFAULT_MODEL_FILE", "best_model.pt")
|
|
@@ -387,18 +398,198 @@ def _live_input_summary(model_bundle, input_text: str) -> str:
|
|
| 387 |
)
|
| 388 |
|
| 389 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
|
| 391 |
if not model_bundle:
|
| 392 |
raise gr.Error("Load a model first.")
|
| 393 |
code, log, used_bundled = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
|
| 394 |
if code != 0:
|
| 395 |
_bundle_task_outputs(model_bundle, output_dir)
|
| 396 |
-
log = f"{log}\n\n--- Live
|
| 397 |
status = f"Task {task} fallback mode: bundled reports + live input analysis."
|
| 398 |
else:
|
| 399 |
if used_bundled:
|
| 400 |
_bundle_task_outputs(model_bundle, output_dir)
|
| 401 |
-
|
|
|
|
| 402 |
else:
|
| 403 |
status = f"Task {task} completed (exit={code})."
|
| 404 |
return status, log
|
|
@@ -485,6 +676,7 @@ CUSTOM_CSS = """
|
|
| 485 |
|
| 486 |
with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
| 487 |
model_state = gr.State(None)
|
|
|
|
| 488 |
|
| 489 |
gr.Markdown(
|
| 490 |
"""
|
|
@@ -532,7 +724,8 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
|
| 532 |
value="analyze",
|
| 533 |
label="Task 4 Phase",
|
| 534 |
)
|
| 535 |
-
run_all_btn = gr.Button("Run All 5 Tasks", variant="primary")
|
|
|
|
| 536 |
|
| 537 |
with gr.Row():
|
| 538 |
task_choice = gr.Dropdown(
|
|
@@ -651,9 +844,25 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
|
| 651 |
outputs=[task_run_status, task_run_log],
|
| 652 |
)
|
| 653 |
run_all_btn.click(
|
| 654 |
-
fn=
|
| 655 |
inputs=[model_state, analysis_output_dir, analysis_input, task4_phase],
|
| 656 |
-
outputs=[task_run_status, task_run_log],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 657 |
)
|
| 658 |
refresh_outputs_btn.click(
|
| 659 |
fn=refresh_task_outputs,
|
|
|
|
| 4 |
import subprocess
|
| 5 |
import sys
|
| 6 |
import shutil
|
| 7 |
+
import threading
|
| 8 |
+
import uuid
|
| 9 |
from datetime import datetime
|
| 10 |
from pathlib import Path
|
| 11 |
|
|
|
|
| 14 |
from huggingface_hub import hf_hub_download
|
| 15 |
|
| 16 |
from config import CONFIG
|
| 17 |
+
from inference import (
|
| 18 |
+
_resolve_device,
|
| 19 |
+
load_model,
|
| 20 |
+
run_inference,
|
| 21 |
+
_decode_clean,
|
| 22 |
+
_decode_with_cleanup,
|
| 23 |
+
_iast_to_deva,
|
| 24 |
+
_compute_cer,
|
| 25 |
+
)
|
| 26 |
from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
|
| 27 |
|
| 28 |
|
| 29 |
RESULTS_DIR = "generated_results"
|
| 30 |
DEFAULT_ANALYSIS_OUT = "analysis_outputs/T4"
|
| 31 |
os.makedirs(RESULTS_DIR, exist_ok=True)
|
| 32 |
+
_BG_JOBS = {}
|
| 33 |
|
| 34 |
HF_DEFAULT_MODEL_REPO = os.environ.get("HF_DEFAULT_MODEL_REPO", "bhsinghgrid/DevaFlow")
|
| 35 |
HF_DEFAULT_MODEL_FILE = os.environ.get("HF_DEFAULT_MODEL_FILE", "best_model.pt")
|
|
|
|
| 398 |
)
|
| 399 |
|
| 400 |
|
| 401 |
+
def _mini_tfidf_scores(text: str) -> dict:
|
| 402 |
+
tokens = [t for t in text.split() if t.strip()]
|
| 403 |
+
if not tokens:
|
| 404 |
+
return {}
|
| 405 |
+
corpus = [
|
| 406 |
+
"dharmo rakṣati rakṣitaḥ",
|
| 407 |
+
"satyameva jayate",
|
| 408 |
+
"ahiṃsā paramo dharmaḥ",
|
| 409 |
+
"vasudhaiva kuṭumbakam",
|
| 410 |
+
"yatra nāryastu pūjyante",
|
| 411 |
+
text,
|
| 412 |
+
]
|
| 413 |
+
docs = [set([t for t in d.split() if t.strip()]) for d in corpus]
|
| 414 |
+
n = len(docs)
|
| 415 |
+
scores = {}
|
| 416 |
+
for tok in tokens:
|
| 417 |
+
df = sum(1 for d in docs if tok in d)
|
| 418 |
+
idf = (1.0 + (n + 1) / (1 + df))
|
| 419 |
+
scores[tok] = round(float(idf), 4)
|
| 420 |
+
return scores
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def _run_single_prediction(model_bundle, text: str, cfg_override: dict | None = None) -> str:
    """Run one inference pass on ``text`` and return the decoded output.

    Args:
        model_bundle: Mapping providing ``"cfg"``, ``"src_tok"``,
            ``"tgt_tok"``, ``"device"`` and ``"model"``.
        text: Source text; it is stripped before encoding.
        cfg_override: Optional entries written into a private copy of
            ``cfg["inference"]`` for this call only.

    Returns:
        The string produced by ``_decode_with_cleanup`` for the first
        sequence returned by ``run_inference``.
    """
    # Work on a deep copy so per-call overrides never leak into the bundle.
    run_cfg = copy.deepcopy(model_bundle["cfg"])
    if cfg_override:
        run_cfg["inference"].update(cfg_override)

    source_tokenizer = model_bundle["src_tok"]
    target_tokenizer = model_bundle["tgt_tok"]
    target_device = torch.device(model_bundle["device"])

    source = text.strip()
    # Truncate the encoded ids to the configured maximum sequence length.
    token_ids = source_tokenizer.encode(source)[: run_cfg["model"]["max_seq_len"]]
    batch = torch.tensor([token_ids], dtype=torch.long, device=target_device)

    generated = run_inference(model_bundle["model"], batch, run_cfg)
    return _decode_with_cleanup(
        target_tokenizer,
        generated[0].tolist(),
        source,
        run_cfg["inference"],
    )
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def _live_task_analysis(model_bundle, task: str, input_text: str) -> str:
|
| 441 |
+
text = input_text.strip()
|
| 442 |
+
if not text:
|
| 443 |
+
return "Live analysis skipped: empty input."
|
| 444 |
+
pred = _run_single_prediction(model_bundle, text)
|
| 445 |
+
toks = [t for t in pred.split() if t]
|
| 446 |
+
uniq = len(set(toks)) / max(1, len(toks))
|
| 447 |
+
|
| 448 |
+
if str(task) == "1":
|
| 449 |
+
t0 = datetime.now()
|
| 450 |
+
_ = _run_single_prediction(model_bundle, text, {"num_steps": 16})
|
| 451 |
+
t1 = datetime.now()
|
| 452 |
+
_ = _run_single_prediction(model_bundle, text, {"num_steps": 64})
|
| 453 |
+
t2 = datetime.now()
|
| 454 |
+
fast_ms = (t1 - t0).total_seconds() * 1000
|
| 455 |
+
full_ms = (t2 - t1).total_seconds() * 1000
|
| 456 |
+
return (
|
| 457 |
+
f"[Live Task1]\n"
|
| 458 |
+
f"Input: {text}\nPrediction: {pred}\n"
|
| 459 |
+
f"Token-length={len(toks)} unique-ratio={uniq:.3f}\n"
|
| 460 |
+
f"Latency proxy: 16-step={fast_ms:.1f}ms, 64-step={full_ms:.1f}ms"
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
if str(task) == "2":
|
| 464 |
+
tfidf = _mini_tfidf_scores(text)
|
| 465 |
+
top = sorted(tfidf.items(), key=lambda kv: kv[1], reverse=True)[:5]
|
| 466 |
+
return (
|
| 467 |
+
f"[Live Task2]\n"
|
| 468 |
+
f"Input: {text}\nPrediction: {pred}\n"
|
| 469 |
+
f"Token-length={len(toks)} unique-ratio={uniq:.3f}\n"
|
| 470 |
+
f"TF-IDF(top): {top}"
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
if str(task) == "3":
|
| 474 |
+
tfidf = _mini_tfidf_scores(text)
|
| 475 |
+
tf_mean = sum(tfidf.values()) / max(1, len(tfidf))
|
| 476 |
+
return (
|
| 477 |
+
f"[Live Task3]\n"
|
| 478 |
+
f"Input: {text}\nPrediction: {pred}\n"
|
| 479 |
+
f"Token-length={len(toks)} unique-ratio={uniq:.3f}\n"
|
| 480 |
+
f"Concept proxy: mean TF-IDF={tf_mean:.3f}"
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
if str(task) == "5":
|
| 484 |
+
ref = _iast_to_deva(text)
|
| 485 |
+
scales = [0.0, 0.5, 1.0, 1.5, 2.0]
|
| 486 |
+
rows = []
|
| 487 |
+
for s in scales:
|
| 488 |
+
cfg_map = {
|
| 489 |
+
"repetition_penalty": 1.1 + 0.15 * s,
|
| 490 |
+
"diversity_penalty": min(1.0, 0.10 * s),
|
| 491 |
+
}
|
| 492 |
+
out = _run_single_prediction(model_bundle, text, cfg_map)
|
| 493 |
+
cer = _compute_cer(out, ref)
|
| 494 |
+
rows.append((s, round(cer, 4), out[:48]))
|
| 495 |
+
return "[Live Task5]\n" + "\n".join([f"λ={r[0]:.1f} CER={r[1]:.4f} out={r[2]}" for r in rows])
|
| 496 |
+
|
| 497 |
+
return _live_input_summary(model_bundle, text)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def _bg_worker(job_id: str, model_bundle, output_dir: str, input_text: str, task4_phase: str):
    """Thread target: run analysis tasks 1-5 in order and publish progress.

    Progress, accumulated log text and the failure count are written into
    ``_BG_JOBS[job_id]`` after every task so a poller can display them.
    A task counts as failed when ``_run_analysis_cmd`` returns a non-zero
    code; a live analysis is appended to the log for failed tasks and for
    tasks served from bundled outputs. If any task failed,
    ``_bundle_task_outputs`` is called once at the end.

    Any exception raised while the tasks run is caught here. Without that,
    the thread would die silently and the job would stay in a "running"
    state forever; instead the job is marked ``state="failed"``,
    ``done=True`` and the error is appended to the log.

    Args:
        job_id: Key of an existing entry in ``_BG_JOBS``.
        model_bundle: Loaded model bundle; ``model_bundle["ckpt_path"]`` is
            passed to ``_run_analysis_cmd``.
        output_dir: Directory handed to the analysis command.
        input_text: User input forwarded to the analysis helpers.
        task4_phase: Phase selector forwarded to ``_run_analysis_cmd``.
    """
    tasks = ("1", "2", "3", "4", "5")
    job = _BG_JOBS[job_id]
    failures = 0
    logs = []

    def _publish(**fields):
        # Every status write refreshes the "updated" timestamp.
        fields["updated"] = datetime.now().isoformat()
        job.update(fields)

    _publish(state="running", progress=0, failures=0)
    try:
        for idx, task in enumerate(tasks, start=1):
            _publish(
                state=f"running task {task}",
                progress=int((idx - 1) * 100 / len(tasks)),
            )
            code, log, used_bundled = _run_analysis_cmd(
                task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase
            )
            logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
            if code != 0:
                failures += 1
                logs.append(f"\n[Live fallback]\n{_live_task_analysis(model_bundle, task, input_text)}\n")
            elif used_bundled:
                logs.append(f"\n[Live bundled summary]\n{_live_task_analysis(model_bundle, task, input_text)}\n")
            _publish(
                log="".join(logs),
                failures=failures,
                progress=int(idx * 100 / len(tasks)),
            )
        if failures:
            _bundle_task_outputs(model_bundle, output_dir)
    except Exception as exc:  # thread boundary: report instead of dying silently
        logs.append(f"\n\n[Background run aborted] {type(exc).__name__}: {exc}\n")
        _publish(state="failed", done=True, log="".join(logs), failures=failures)
        return

    _publish(
        state="done",
        done=True,
        progress=100,
        log="".join(logs),
        failures=failures,
    )
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
def start_run_all_background(model_bundle, output_dir, input_text, task4_phase):
    """Register a new background job and launch its worker thread.

    Args:
        model_bundle: Loaded model bundle; a falsy value is rejected.
        output_dir: Directory for task outputs; created if missing.
        input_text: User input forwarded to the worker.
        task4_phase: Phase selector forwarded to the worker.

    Returns:
        ``(status_message, log_message, job_id)``.

    Raises:
        gr.Error: If no model has been loaded.
    """
    if not model_bundle:
        raise gr.Error("Load a model first.")

    os.makedirs(output_dir, exist_ok=True)

    # Short random identifier: first 10 hex characters of a UUID4.
    new_id = uuid.uuid4().hex[:10]
    _BG_JOBS[new_id] = dict(
        state="queued",
        progress=0,
        log="",
        failures=0,
        done=False,
        output_dir=output_dir,
        created=datetime.now().isoformat(),
        updated=datetime.now().isoformat(),
    )

    # Daemon thread, so a still-running job does not block interpreter exit.
    worker_args = (new_id, model_bundle, output_dir, input_text, task4_phase)
    threading.Thread(target=_bg_worker, args=worker_args, daemon=True).start()

    started_msg = f"Background run started. Job ID: {new_id}"
    queued_msg = f"Job {new_id} queued..."
    return started_msg, queued_msg, new_id
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def poll_run_all_background(job_id, output_dir):
    """Report the state of a background job together with refreshed outputs.

    Args:
        job_id: Identifier returned by ``start_run_all_background``; empty or
            unknown ids produce a "no active job" message.
        output_dir: Output directory currently selected in the UI. It is used
            when no job is found, or when the job record has no directory.

    Returns:
        A tuple ``(status, log, *outputs)`` where ``outputs`` is whatever
        ``refresh_task_outputs`` returns for the chosen directory.
    """
    job = _BG_JOBS.get(job_id) if job_id else None
    if job is None:
        msg = "No active background job. Start Run All 5 Tasks first."
        return (msg, msg, *refresh_task_outputs(output_dir))

    status = (
        f"Job {job_id} | state={job['state']} | progress={job['progress']}% | "
        f"failures={job['failures']} | updated={job['updated']}"
    )
    # Show the artifacts of the directory this job actually writes to; the UI
    # field may have been edited since the job was started, which would pair
    # this job's status/log with another directory's outputs.
    job_dir = job.get("output_dir") or output_dir
    return (status, job.get("log", ""), *refresh_task_outputs(job_dir))
|
| 578 |
+
|
| 579 |
+
|
| 580 |
def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
    """Run one analysis task and return ``(status, log)`` for the UI.

    When the analysis command exits with 0 and did not use bundled outputs,
    its log is returned unchanged. Otherwise the bundled outputs are copied
    via ``_bundle_task_outputs`` and a live analysis of ``input_text`` is
    appended to the log; the status text says whether this was a fallback
    after a failure or a bundled result.

    Raises:
        gr.Error: If no model has been loaded.
    """
    if not model_bundle:
        raise gr.Error("Load a model first.")

    exit_code, report, from_bundle = _run_analysis_cmd(
        task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase
    )

    # Plain success: nothing to bundle and nothing to append.
    if exit_code == 0 and not from_bundle:
        return f"Task {task} completed (exit={exit_code}).", report

    # Both remaining paths publish the bundled outputs first, then append
    # the live analysis to the command log.
    _bundle_task_outputs(model_bundle, output_dir)
    live_section = _live_task_analysis(model_bundle, task, input_text)
    report = f"{report}\n\n--- Live task analysis ---\n{live_section}"

    if exit_code != 0:
        outcome = f"Task {task} fallback mode: bundled reports + live input analysis."
    else:
        outcome = f"Task {task} loaded from bundled analysis outputs + live analysis."
    return outcome, report
|
|
|
|
| 676 |
|
| 677 |
with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
| 678 |
model_state = gr.State(None)
|
| 679 |
+
bg_job_state = gr.State("")
|
| 680 |
|
| 681 |
gr.Markdown(
|
| 682 |
"""
|
|
|
|
| 724 |
value="analyze",
|
| 725 |
label="Task 4 Phase",
|
| 726 |
)
|
| 727 |
+
run_all_btn = gr.Button("Run All 5 Tasks (Background)", variant="primary")
|
| 728 |
+
track_bg_btn = gr.Button("Track Background Run")
|
| 729 |
|
| 730 |
with gr.Row():
|
| 731 |
task_choice = gr.Dropdown(
|
|
|
|
| 844 |
outputs=[task_run_status, task_run_log],
|
| 845 |
)
|
| 846 |
run_all_btn.click(
|
| 847 |
+
fn=start_run_all_background,
|
| 848 |
inputs=[model_state, analysis_output_dir, analysis_input, task4_phase],
|
| 849 |
+
outputs=[task_run_status, task_run_log, bg_job_state],
|
| 850 |
+
)
|
| 851 |
+
track_bg_btn.click(
|
| 852 |
+
fn=poll_run_all_background,
|
| 853 |
+
inputs=[bg_job_state, analysis_output_dir],
|
| 854 |
+
outputs=[
|
| 855 |
+
task_run_status,
|
| 856 |
+
task_run_log,
|
| 857 |
+
task1_box,
|
| 858 |
+
task2_box,
|
| 859 |
+
task2_drift_img,
|
| 860 |
+
task2_attn_img,
|
| 861 |
+
task3_box,
|
| 862 |
+
task3_img,
|
| 863 |
+
task5_box,
|
| 864 |
+
task4_img,
|
| 865 |
+
],
|
| 866 |
)
|
| 867 |
refresh_outputs_btn.click(
|
| 868 |
fn=refresh_task_outputs,
|