bhsinghgrid committed on
Commit
5fd6ec8
·
verified ·
1 Parent(s): 5953b4e

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +215 -6
app.py CHANGED
@@ -4,6 +4,8 @@ import os
4
  import subprocess
5
  import sys
6
  import shutil
 
 
7
  from datetime import datetime
8
  from pathlib import Path
9
 
@@ -12,13 +14,22 @@ import torch
12
  from huggingface_hub import hf_hub_download
13
 
14
  from config import CONFIG
15
- from inference import _resolve_device, load_model, run_inference, _decode_clean, _decode_with_cleanup
 
 
 
 
 
 
 
 
16
  from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
17
 
18
 
19
  RESULTS_DIR = "generated_results"
20
  DEFAULT_ANALYSIS_OUT = "analysis_outputs/T4"
21
  os.makedirs(RESULTS_DIR, exist_ok=True)
 
22
 
23
  HF_DEFAULT_MODEL_REPO = os.environ.get("HF_DEFAULT_MODEL_REPO", "bhsinghgrid/DevaFlow")
24
  HF_DEFAULT_MODEL_FILE = os.environ.get("HF_DEFAULT_MODEL_FILE", "best_model.pt")
@@ -387,18 +398,198 @@ def _live_input_summary(model_bundle, input_text: str) -> str:
387
  )
388
 
389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
391
  if not model_bundle:
392
  raise gr.Error("Load a model first.")
393
  code, log, used_bundled = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
394
  if code != 0:
395
  _bundle_task_outputs(model_bundle, output_dir)
396
- log = f"{log}\n\n--- Live input summary ---\n{_live_input_summary(model_bundle, input_text)}"
397
  status = f"Task {task} fallback mode: bundled reports + live input analysis."
398
  else:
399
  if used_bundled:
400
  _bundle_task_outputs(model_bundle, output_dir)
401
- status = f"Task {task} loaded from bundled analysis outputs."
 
402
  else:
403
  status = f"Task {task} completed (exit={code})."
404
  return status, log
@@ -485,6 +676,7 @@ CUSTOM_CSS = """
485
 
486
  with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
487
  model_state = gr.State(None)
 
488
 
489
  gr.Markdown(
490
  """
@@ -532,7 +724,8 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
532
  value="analyze",
533
  label="Task 4 Phase",
534
  )
535
- run_all_btn = gr.Button("Run All 5 Tasks", variant="primary")
 
536
 
537
  with gr.Row():
538
  task_choice = gr.Dropdown(
@@ -651,9 +844,25 @@ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
651
  outputs=[task_run_status, task_run_log],
652
  )
653
  run_all_btn.click(
654
- fn=run_all_tasks,
655
  inputs=[model_state, analysis_output_dir, analysis_input, task4_phase],
656
- outputs=[task_run_status, task_run_log],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
657
  )
658
  refresh_outputs_btn.click(
659
  fn=refresh_task_outputs,
 
4
  import subprocess
5
  import sys
6
  import shutil
7
+ import threading
8
+ import uuid
9
  from datetime import datetime
10
  from pathlib import Path
11
 
 
14
  from huggingface_hub import hf_hub_download
15
 
16
  from config import CONFIG
17
+ from inference import (
18
+ _resolve_device,
19
+ load_model,
20
+ run_inference,
21
+ _decode_clean,
22
+ _decode_with_cleanup,
23
+ _iast_to_deva,
24
+ _compute_cer,
25
+ )
26
  from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
27
 
28
 
29
  RESULTS_DIR = "generated_results"
30
  DEFAULT_ANALYSIS_OUT = "analysis_outputs/T4"
31
  os.makedirs(RESULTS_DIR, exist_ok=True)
32
+ _BG_JOBS = {}
33
 
34
  HF_DEFAULT_MODEL_REPO = os.environ.get("HF_DEFAULT_MODEL_REPO", "bhsinghgrid/DevaFlow")
35
  HF_DEFAULT_MODEL_FILE = os.environ.get("HF_DEFAULT_MODEL_FILE", "best_model.pt")
 
398
  )
399
 
400
 
401
+ def _mini_tfidf_scores(text: str) -> dict:
402
+ tokens = [t for t in text.split() if t.strip()]
403
+ if not tokens:
404
+ return {}
405
+ corpus = [
406
+ "dharmo rakṣati rakṣitaḥ",
407
+ "satyameva jayate",
408
+ "ahiṃsā paramo dharmaḥ",
409
+ "vasudhaiva kuṭumbakam",
410
+ "yatra nāryastu pūjyante",
411
+ text,
412
+ ]
413
+ docs = [set([t for t in d.split() if t.strip()]) for d in corpus]
414
+ n = len(docs)
415
+ scores = {}
416
+ for tok in tokens:
417
+ df = sum(1 for d in docs if tok in d)
418
+ idf = (1.0 + (n + 1) / (1 + df))
419
+ scores[tok] = round(float(idf), 4)
420
+ return scores
421
+
422
+
423
def _run_single_prediction(model_bundle, text: str, cfg_override: dict | None = None) -> str:
    """Run one forward inference pass and return the cleaned decoded string.

    Args:
        model_bundle: dict built by the model loader; this function reads its
            "cfg", "src_tok", "tgt_tok", "device", and "model" entries.
        text: raw input text; stripped before encoding.
        cfg_override: optional mapping of inference-config keys to override
            for this call only.

    Returns:
        The decoded prediction string after cleanup.
    """
    # Deep-copy so per-call overrides never leak into the shared bundle cfg.
    # NOTE(review): relies on `copy` being imported at module top — the shown
    # diff hunks do not include it; confirm against the full file.
    cfg = copy.deepcopy(model_bundle["cfg"])
    if cfg_override:
        for k, v in cfg_override.items():
            cfg["inference"][k] = v
    src_tok = model_bundle["src_tok"]
    tgt_tok = model_bundle["tgt_tok"]
    device = torch.device(model_bundle["device"])
    # Batch of one: encode, then truncate to the model's max sequence length.
    input_ids = torch.tensor(
        [src_tok.encode(text.strip())[:cfg["model"]["max_seq_len"]]],
        dtype=torch.long,
        device=device,
    )
    out = run_inference(model_bundle["model"], input_ids, cfg)
    return _decode_with_cleanup(tgt_tok, out[0].tolist(), text.strip(), cfg["inference"])
438
+
439
+
440
def _live_task_analysis(model_bundle, task: str, input_text: str) -> str:
    """Produce an in-process ("live") analysis report for one analysis task.

    Used when the task's analysis subprocess fails or serves pre-bundled
    outputs, so the user still gets results computed against their own input.
    Task-specific reports:
      "1": latency proxy comparing 16-step vs 64-step inference,
      "2": top-5 mini TF-IDF token scores,
      "3": mean TF-IDF as a concept proxy,
      "5": CER across a sweep of penalty scales,
      anything else: the generic _live_input_summary.
    """
    text = input_text.strip()
    if not text:
        return "Live analysis skipped: empty input."
    # Baseline prediction shared by the task-1/2/3 reports.
    pred = _run_single_prediction(model_bundle, text)
    toks = [t for t in pred.split() if t]
    # Distinct-token ratio; max(1, ...) guards an empty prediction.
    uniq = len(set(toks)) / max(1, len(toks))

    if str(task) == "1":
        # Wall-clock latency proxy: time a fast (16-step) vs full (64-step) run.
        t0 = datetime.now()
        _ = _run_single_prediction(model_bundle, text, {"num_steps": 16})
        t1 = datetime.now()
        _ = _run_single_prediction(model_bundle, text, {"num_steps": 64})
        t2 = datetime.now()
        fast_ms = (t1 - t0).total_seconds() * 1000
        full_ms = (t2 - t1).total_seconds() * 1000
        return (
            f"[Live Task1]\n"
            f"Input: {text}\nPrediction: {pred}\n"
            f"Token-length={len(toks)} unique-ratio={uniq:.3f}\n"
            f"Latency proxy: 16-step={fast_ms:.1f}ms, 64-step={full_ms:.1f}ms"
        )

    if str(task) == "2":
        # Report the five highest-weighted input tokens.
        tfidf = _mini_tfidf_scores(text)
        top = sorted(tfidf.items(), key=lambda kv: kv[1], reverse=True)[:5]
        return (
            f"[Live Task2]\n"
            f"Input: {text}\nPrediction: {pred}\n"
            f"Token-length={len(toks)} unique-ratio={uniq:.3f}\n"
            f"TF-IDF(top): {top}"
        )

    if str(task) == "3":
        # Mean TF-IDF over the input tokens, used as a rough concept proxy.
        tfidf = _mini_tfidf_scores(text)
        tf_mean = sum(tfidf.values()) / max(1, len(tfidf))
        return (
            f"[Live Task3]\n"
            f"Input: {text}\nPrediction: {pred}\n"
            f"Token-length={len(toks)} unique-ratio={uniq:.3f}\n"
            f"Concept proxy: mean TF-IDF={tf_mean:.3f}"
        )

    if str(task) == "5":
        # Reference text — presumably IAST romanization converted to
        # Devanagari by the helper; confirm against inference.py.
        ref = _iast_to_deva(text)
        scales = [0.0, 0.5, 1.0, 1.5, 2.0]
        rows = []
        for s in scales:
            # Penalties grow with the scale λ; diversity is capped at 1.0.
            cfg_map = {
                "repetition_penalty": 1.1 + 0.15 * s,
                "diversity_penalty": min(1.0, 0.10 * s),
            }
            out = _run_single_prediction(model_bundle, text, cfg_map)
            cer = _compute_cer(out, ref)
            # Keep only the first 48 chars of each output for the table.
            rows.append((s, round(cer, 4), out[:48]))
        return "[Live Task5]\n" + "\n".join([f"λ={r[0]:.1f} CER={r[1]:.4f} out={r[2]}" for r in rows])

    # Task "4" (and anything unrecognized) falls back to the generic summary.
    return _live_input_summary(model_bundle, text)
498
+
499
+
500
def _bg_worker(job_id: str, model_bundle, output_dir: str, input_text: str, task4_phase: str):
    """Run all five analysis tasks sequentially on a background thread.

    Started as a daemon thread by start_run_all_background().  Publishes
    progress through the module-level _BG_JOBS entry keyed by *job_id*:
      state / progress / updated -- live status polled by the UI,
      log                        -- accumulated per-task output,
      failures                   -- tasks whose subprocess exited non-zero,
      done                       -- True once the run finished (or crashed).

    Fix: the whole run is wrapped in try/except — previously an unexpected
    exception (e.g. from _run_analysis_cmd or _live_task_analysis) killed the
    thread silently and left the job stuck in a "running" state forever from
    the poller's point of view.
    """
    tasks = ["1", "2", "3", "4", "5"]
    failures = 0
    logs = []
    _BG_JOBS[job_id].update({"state": "running", "progress": 0, "failures": 0, "updated": datetime.now().isoformat()})
    try:
        for idx, task in enumerate(tasks, start=1):
            _BG_JOBS[job_id].update(
                {
                    "state": f"running task {task}",
                    "progress": int((idx - 1) * 100 / len(tasks)),
                    "updated": datetime.now().isoformat(),
                }
            )
            code, log, used_bundled = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
            logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
            if code != 0:
                failures += 1
                # Subprocess failed: fall back to in-process analysis so the
                # user still gets a result for this task.
                logs.append(f"\n[Live fallback]\n{_live_task_analysis(model_bundle, task, input_text)}\n")
            elif used_bundled:
                logs.append(f"\n[Live bundled summary]\n{_live_task_analysis(model_bundle, task, input_text)}\n")
            _BG_JOBS[job_id].update(
                {
                    "log": "".join(logs),
                    "failures": failures,
                    "progress": int(idx * 100 / len(tasks)),
                    "updated": datetime.now().isoformat(),
                }
            )
        if failures:
            # At least one task fell back; copy the bundled reports into the
            # output dir so the UI refresh can display them.
            _bundle_task_outputs(model_bundle, output_dir)
        _BG_JOBS[job_id].update(
            {
                "state": "done",
                "done": True,
                "progress": 100,
                "log": "".join(logs),
                "failures": failures,
                "updated": datetime.now().isoformat(),
            }
        )
    except Exception as exc:  # noqa: BLE001 -- worker thread must never die silently
        _BG_JOBS[job_id].update(
            {
                "state": f"error: {exc}",
                "done": True,
                "progress": 100,
                "log": "".join(logs) + f"\n\n[Worker error] {exc}",
                "failures": failures + 1,
                "updated": datetime.now().isoformat(),
            }
        )
540
+
541
+
542
def start_run_all_background(model_bundle, output_dir, input_text, task4_phase):
    """Launch a daemon thread running all five tasks.

    Returns a (status_message, initial_log, job_id) tuple; the job_id is
    stored in UI state and later passed to poll_run_all_background().
    Raises gr.Error when no model has been loaded.
    """
    if not model_bundle:
        raise gr.Error("Load a model first.")
    os.makedirs(output_dir, exist_ok=True)
    # Short random id is enough — jobs live only in this process's memory.
    job_id = uuid.uuid4().hex[:10]
    _BG_JOBS[job_id] = {
        "state": "queued",
        "progress": 0,
        "log": "",
        "failures": 0,
        "done": False,
        "output_dir": output_dir,
        "created": datetime.now().isoformat(),
        "updated": datetime.now().isoformat(),
    }
    worker = threading.Thread(
        target=_bg_worker,
        args=(job_id, model_bundle, output_dir, input_text, task4_phase),
        daemon=True,
    )
    worker.start()
    return f"Background run started. Job ID: {job_id}", f"Job {job_id} queued...", job_id
564
+
565
+
566
def poll_run_all_background(job_id, output_dir):
    """Report a background job's status plus refreshed task-output widgets.

    Returns (status_line, log_text, *refresh_task_outputs(output_dir)); when
    no job matches, both text fields carry the same guidance message.
    """
    job = _BG_JOBS.get(job_id) if job_id else None
    if job is None:
        msg = "No active background job. Start Run All 5 Tasks first."
        # Still refresh the output widgets so the UI stays consistent.
        return msg, msg, *refresh_task_outputs(output_dir)
    status = (
        f"Job {job_id} | state={job['state']} | progress={job['progress']}% | "
        f"failures={job['failures']} | updated={job['updated']}"
    )
    return status, job.get("log", ""), *refresh_task_outputs(output_dir)
578
+
579
+
580
def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
    """Run a single analysis task and return a (status, log) pair.

    When the subprocess fails or serves pre-bundled outputs, the bundled
    reports are copied and a live in-process analysis is appended to the log.
    Raises gr.Error when no model has been loaded.
    """
    if not model_bundle:
        raise gr.Error("Load a model first.")
    exit_code, task_log, used_bundled = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)

    if exit_code == 0 and not used_bundled:
        # Clean run: nothing extra to attach.
        return f"Task {task} completed (exit={exit_code}).", task_log

    # Failure or bundled-output path: surface bundled reports and augment the
    # log with an analysis computed live against the user's input.
    _bundle_task_outputs(model_bundle, output_dir)
    task_log = f"{task_log}\n\n--- Live task analysis ---\n{_live_task_analysis(model_bundle, task, input_text)}"
    if exit_code != 0:
        status = f"Task {task} fallback mode: bundled reports + live input analysis."
    else:
        status = f"Task {task} loaded from bundled analysis outputs + live analysis."
    return status, task_log
 
676
 
677
  with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
678
  model_state = gr.State(None)
679
+ bg_job_state = gr.State("")
680
 
681
  gr.Markdown(
682
  """
 
724
  value="analyze",
725
  label="Task 4 Phase",
726
  )
727
+ run_all_btn = gr.Button("Run All 5 Tasks (Background)", variant="primary")
728
+ track_bg_btn = gr.Button("Track Background Run")
729
 
730
  with gr.Row():
731
  task_choice = gr.Dropdown(
 
844
  outputs=[task_run_status, task_run_log],
845
  )
846
  run_all_btn.click(
847
+ fn=start_run_all_background,
848
  inputs=[model_state, analysis_output_dir, analysis_input, task4_phase],
849
+ outputs=[task_run_status, task_run_log, bg_job_state],
850
+ )
851
+ track_bg_btn.click(
852
+ fn=poll_run_all_background,
853
+ inputs=[bg_job_state, analysis_output_dir],
854
+ outputs=[
855
+ task_run_status,
856
+ task_run_log,
857
+ task1_box,
858
+ task2_box,
859
+ task2_drift_img,
860
+ task2_attn_img,
861
+ task3_box,
862
+ task3_img,
863
+ task5_box,
864
+ task4_img,
865
+ ],
866
  )
867
  refresh_outputs_btn.click(
868
  fn=refresh_task_outputs,