import copy import json import os import subprocess import sys import shutil import threading import uuid import time from datetime import datetime from pathlib import Path import gradio as gr import torch from huggingface_hub import hf_hub_download from config import CONFIG from inference import ( _resolve_device, load_model, run_inference, _decode_clean, _decode_with_cleanup, _iast_to_deva, _compute_cer, ) from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer RESULTS_DIR = "generated_results" DEFAULT_ANALYSIS_OUT = "analysis_outputs/outputs_all_models_20260325/T4" os.makedirs(RESULTS_DIR, exist_ok=True) _BG_JOBS = {} _CHECKPOINT_CACHE = None try: import mlflow except Exception: mlflow = None _MLFLOW_READY = False FLOW_STEPS = [ "Start", "Load Model (checkpoint/config/device/eval)", "Load Tokenizers", "Input (IAST)", "Source Tokenization", "Encoder (run once)", "KV Cache prepared", "Initialize x_T (MASK)", "Diffusion loop (T→0, with Task2/Task3 hooks)", "Final x0", "Decode to Devanagari", "Evaluation/Tasks (Task4/Task5)", ] def _setup_mlflow_once(): global _MLFLOW_READY if _MLFLOW_READY: return if mlflow is None: return try: tracking_uri = os.environ.get("MLFLOW_TRACKING_URI", "file:/tmp/mlruns") experiment = os.environ.get("MLFLOW_EXPERIMENT_NAME", "hf-space-sanskrit-d3pm") mlflow.set_tracking_uri(tracking_uri) mlflow.set_experiment(experiment) _MLFLOW_READY = True except Exception: _MLFLOW_READY = False def _mlflow_event(run_name: str, params: dict | None = None, metrics: dict | None = None, tags: dict | None = None): _setup_mlflow_once() if not _MLFLOW_READY or mlflow is None: return try: with mlflow.start_run(run_name=run_name, nested=False): if tags: mlflow.set_tags({k: str(v) for k, v in tags.items()}) if params: mlflow.log_params({k: (v if isinstance(v, (int, float, str, bool)) else str(v)) for k, v in params.items()}) if metrics: mlflow.log_metrics({k: float(v) for k, v in metrics.items()}) except Exception: pass def _build_flow_markdown(model_loaded=False, inference_ready=False, task_states=None): lines = ["### Execution Flow"] task_states = task_states or {} any_task_activity = any(v != "pending" for v in task_states.values()) if task_states else False for i, step in enumerate(FLOW_STEPS, start=1): status = "⬜" if model_loaded and i <= 3: status = "✅" if (inference_ready or model_loaded) and i <= 11: status = "✅" if i == 12 and any_task_activity: status = "✅" lines.append(f"{status} {i}. {step}") if task_states: lines.append("") lines.append("### Task Status") for k in ["1", "2", "3", "4", "5"]: v = task_states.get(k, "pending") icon = "✅" if v == "done" else ("🔄" if v.startswith("running") else ("❌" if v == "failed" else "⬜")) lines.append(f"{icon} Task {k}: {v}") return "\n".join(lines) def _task5_cfg(lambda_min, lambda_max, lambda_step, task5_samples): lo = float(lambda_min) hi = float(lambda_max) st = max(0.1, float(lambda_step)) if hi < lo: lo, hi = hi, lo vals = [] cur = lo while cur <= hi + 1e-9 and len(vals) < 30: vals.append(round(cur, 2)) cur += st if not vals: vals = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0] return {"scales": vals, "samples": max(5, int(task5_samples))} HF_DEFAULT_MODEL_REPO = os.environ.get("HF_DEFAULT_MODEL_REPO", "bhsinghgrid/DevaFlow") HF_DEFAULT_MODEL_FILE = os.environ.get("HF_DEFAULT_MODEL_FILE", "best_model.pt") HF_CHECKPOINT_REPO = os.environ.get("HF_CHECKPOINT_REPO", "bhsinghgrid/devflow2") HF_CHECKPOINT_FILE = os.environ.get("HF_CHECKPOINT_FILE", "best_model.pt") HF_MODEL_REPOS = [ repo.strip() for repo in os.environ.get("HF_MODEL_REPOS", "bhsinghgrid/DevaFlow,bhsinghgrid/devflow2").split(",") if repo.strip() ] HF_DEFAULT_MODEL_TYPE = os.environ.get("HF_DEFAULT_MODEL_TYPE", "d3pm_cross_attention") HF_DEFAULT_INCLUDE_NEG = os.environ.get("HF_DEFAULT_INCLUDE_NEG", "false") HF_DEFAULT_NUM_STEPS = os.environ.get("HF_DEFAULT_NUM_STEPS") HF_DEFAULT_MODEL_SETTINGS_FILE = os.environ.get("HF_DEFAULT_MODEL_SETTINGS_FILE", "model_settings.json") def _download_hf_model_settings(): try: cache_dir = Path(".hf_model_cache") cache_dir.mkdir(parents=True, exist_ok=True) settings_path = hf_hub_download( repo_id=HF_DEFAULT_MODEL_REPO, filename=HF_DEFAULT_MODEL_SETTINGS_FILE, local_dir=str(cache_dir), ) with open(settings_path, "r", encoding="utf-8") as f: data = json.load(f) return data if isinstance(data, dict) else {} except Exception: return {} HF_DEFAULT_SETTINGS = _download_hf_model_settings() def _repo_cache_dir(repo_id: str) -> Path: safe = repo_id.replace("/", "__") path = Path(".hf_model_cache") / safe path.mkdir(parents=True, exist_ok=True) return path def _download_hf_checkpoint(repo_id: str, filename: str = "best_model.pt"): try: cache_dir = _repo_cache_dir(repo_id) return hf_hub_download( repo_id=repo_id, filename=filename, local_dir=str(cache_dir), ) except Exception: return None def _download_hf_settings_for_repo(repo_id: str): try: cache_dir = _repo_cache_dir(repo_id) settings_path = hf_hub_download( repo_id=repo_id, filename=HF_DEFAULT_MODEL_SETTINGS_FILE, local_dir=str(cache_dir), ) with open(settings_path, "r", encoding="utf-8") as f: data = json.load(f) return data if isinstance(data, dict) else {} except Exception: return {} def discover_checkpoints(): global _CHECKPOINT_CACHE if _CHECKPOINT_CACHE is not None: return list(_CHECKPOINT_CACHE) found = [] local_roots = [ ("ablation_results", "cross_attention"), (os.path.join("ablation_results", "encoder_decoder"), "encoder_decoder"), ("results7", "other"), ("results", "other"), ] for root, family in local_roots: if not os.path.isdir(root): continue for entry in sorted(os.listdir(root)): ckpt = os.path.join(root, entry, "best_model.pt") if not os.path.exists(ckpt): continue found.append( { "label": f"{entry} [{root}]", "path": ckpt, "experiment": entry, "root": root, "family": family, } ) for repo_id in HF_MODEL_REPOS: settings = _download_hf_settings_for_repo(repo_id) model_type = settings.get("model_type", "") family = "encoder_decoder" if model_type == "d3pm_encoder_decoder" else "cross_attention" num_steps = settings.get("num_steps", HF_DEFAULT_NUM_STEPS) step_label = f"T{num_steps}" if num_steps else "HF" found.append( { "label": f"{repo_id} [{family}:{step_label}]", "path": None, "experiment": step_label, "root": "hf", "family": family, "repo_id": repo_id, "repo_file": HF_CHECKPOINT_FILE, "hf_settings": settings, } ) _CHECKPOINT_CACHE = list(found) return list(found) def _guess_analysis_dir(experiment: str, ckpt_path: str, family: str = "cross_attention", settings: dict | None = None) -> str: settings = settings or {} base = Path("analysis_outputs") packaged = base / "outputs_all_models_20260325" step = None if experiment and experiment.startswith("T") and experiment[1:].isdigit(): step = experiment elif settings.get("num_steps"): step = f"T{int(settings['num_steps'])}" else: for part in Path(ckpt_path or "").parts: if part.startswith("T") and part[1:].isdigit(): step = part break if packaged.exists() and step: if family == "encoder_decoder" and (packaged / "encoder_decoder" / step).is_dir(): return str(packaged / "encoder_decoder" / step) if (packaged / step).is_dir(): return str(packaged / step) if base.exists(): if family == "encoder_decoder" and step and (base / "encoder_decoder" / step).is_dir(): return str(base / "encoder_decoder" / step) if step and (base / step).is_dir(): return str(base / step) if (base / "T4").is_dir(): return str(base / "T4") return os.path.join("analysis", "outputs_ui", experiment or "default") def checkpoint_map(): return {item["label"]: item for item in discover_checkpoints()} def default_checkpoint_label(): cps = discover_checkpoints() if not cps: return None for item in cps: path = item.get("path") if path and path.endswith("ablation_results/T4/best_model.pt"): return item["label"] for item in cps: if item.get("repo_id") == HF_DEFAULT_MODEL_REPO: return item["label"] return cps[0]["label"] def infer_model_type(experiment_name: str, root: str = "") -> str: if root == "ablation_results": return "d3pm_cross_attention" if experiment_name.startswith("d3pm_cross_attention"): return "d3pm_cross_attention" if experiment_name.startswith("d3pm_encoder_decoder"): return "d3pm_encoder_decoder" if experiment_name.startswith("baseline_cross_attention"): return "baseline_cross_attention" if experiment_name.startswith("baseline_encoder_decoder"): return "baseline_encoder_decoder" return CONFIG["model_type"] def infer_include_negative(experiment_name: str, root: str = "") -> bool: if root == "ablation_results": return False if "_neg_True" in experiment_name: return True if "_neg_False" in experiment_name: return False return CONFIG["data"]["include_negative_examples"] def build_runtime_cfg(ckpt_path: str, item: dict | None = None): item = item or {} if item.get("root") == "hf": experiment = item.get("experiment", "hf") root = "hf" hf_settings = item.get("hf_settings", {}) else: experiment = os.path.basename(os.path.dirname(ckpt_path)) root = os.path.basename(os.path.dirname(os.path.dirname(ckpt_path))) hf_settings = {} cfg = copy.deepcopy(CONFIG) if root == "hf": cfg["model_type"] = ( hf_settings.get("model_type") or os.environ.get("HF_DEFAULT_MODEL_TYPE") or HF_DEFAULT_SETTINGS.get("model_type") or HF_DEFAULT_MODEL_TYPE ) include_neg_raw = str( hf_settings.get( "include_negative_examples", os.environ.get( "HF_DEFAULT_INCLUDE_NEG", HF_DEFAULT_SETTINGS.get("include_negative_examples", HF_DEFAULT_INCLUDE_NEG), ), ) ) cfg["data"]["include_negative_examples"] = include_neg_raw.lower() == "true" t_raw = ( hf_settings.get("num_steps") or os.environ.get("HF_DEFAULT_NUM_STEPS") or HF_DEFAULT_SETTINGS.get("num_steps") or HF_DEFAULT_NUM_STEPS ) if t_raw: t_val = int(t_raw) cfg["model"]["diffusion_steps"] = t_val cfg["inference"]["num_steps"] = t_val else: cfg["model_type"] = infer_model_type(experiment, root=root) cfg["data"]["include_negative_examples"] = infer_include_negative(experiment, root=root) if root == "ablation_results" and experiment.startswith("T") and experiment[1:].isdigit(): t_val = int(experiment[1:]) cfg["model"]["diffusion_steps"] = t_val cfg["inference"]["num_steps"] = t_val device = _resolve_device(cfg.get("training", {}).get("device", "cpu")) return cfg, device, experiment def _build_tokenizers(cfg): src_tok = SanskritSourceTokenizer( vocab_size=cfg["model"].get("src_vocab_size", 16000), max_len=cfg["model"]["max_seq_len"], ) tgt_tok = SanskritTargetTokenizer( vocab_size=cfg["model"].get("tgt_vocab_size", 16000), max_len=cfg["model"]["max_seq_len"], ) return src_tok, tgt_tok def load_selected_model(checkpoint_label): mapping = checkpoint_map() if not mapping: raise gr.Error("No checkpoints found. Add models under ablation_results/ or results*/.") if not checkpoint_label: checkpoint_label = default_checkpoint_label() if checkpoint_label not in mapping: raise gr.Error("Selected checkpoint not found. Click refresh.") item = mapping[checkpoint_label] ckpt_path = item.get("path") if item.get("root") == "hf": ckpt_path = _download_hf_checkpoint(item["repo_id"], item.get("repo_file", HF_CHECKPOINT_FILE)) if not ckpt_path or not os.path.exists(ckpt_path): raise gr.Error(f"Failed to download checkpoint from {item['repo_id']}.") item["path"] = ckpt_path cfg, device, experiment = build_runtime_cfg(ckpt_path, item=item) model, cfg = load_model(ckpt_path, cfg, device) src_tok, tgt_tok = _build_tokenizers(cfg) bundle = { "ckpt_path": ckpt_path, "experiment": experiment, "family": item.get("family", "cross_attention"), "repo_id": item.get("repo_id"), "hf_settings": item.get("hf_settings", {}), "device": str(device), "cfg": cfg, "model": model, "src_tok": src_tok, "tgt_tok": tgt_tok, } model_info = { "checkpoint": ckpt_path, "experiment": experiment, "model_type": cfg["model_type"], "include_negatives": cfg["data"]["include_negative_examples"], "device": str(device), "max_seq_len": cfg["model"]["max_seq_len"], "diffusion_steps": cfg["model"]["diffusion_steps"], "inference_steps": cfg["inference"]["num_steps"], "d_model": cfg["model"]["d_model"], "n_layers": cfg["model"]["n_layers"], "n_heads": cfg["model"]["n_heads"], "family": item.get("family", "cross_attention"), "repo_id": item.get("repo_id"), } status = f"Loaded `{experiment}` on `{device}` (`{cfg['model_type']}`)" suggested_out = _guess_analysis_dir( experiment, ckpt_path, family=item.get("family", "cross_attention"), settings=item.get("hf_settings", {}), ) return bundle, status, model_info, cfg["inference"]["num_steps"], suggested_out def apply_preset(preset_name): presets = { "Manual": (0.70, 40, 1.20, 0.0), "Literal": (0.60, 20, 1.25, 0.0), "Balanced": (0.70, 40, 1.20, 0.0), "Creative": (0.90, 80, 1.05, 0.2), } return presets.get(preset_name, presets["Balanced"]) def save_generation(experiment, record): ts = datetime.now().strftime("%Y%m%d") path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json") existing = [] if os.path.exists(path): with open(path, "r", encoding="utf-8") as f: existing = json.load(f) existing.append(record) with open(path, "w", encoding="utf-8") as f: json.dump(existing, f, ensure_ascii=False, indent=2) return path def generate_from_ui( model_bundle, input_text, temperature, top_k, repetition_penalty, diversity_penalty, num_steps, clean_output, ): if not model_bundle: raise gr.Error("Load a model first.") if not input_text.strip(): raise gr.Error("Enter input text first.") t0 = time.perf_counter() cfg = copy.deepcopy(model_bundle["cfg"]) cfg["inference"]["temperature"] = float(temperature) cfg["inference"]["top_k"] = int(top_k) cfg["inference"]["repetition_penalty"] = float(repetition_penalty) cfg["inference"]["diversity_penalty"] = float(diversity_penalty) cfg["inference"]["num_steps"] = int(num_steps) src_tok = model_bundle["src_tok"] tgt_tok = model_bundle["tgt_tok"] device = torch.device(model_bundle["device"]) input_ids = torch.tensor( [src_tok.encode(input_text.strip())[:cfg["model"]["max_seq_len"]]], dtype=torch.long, device=device, ) out = run_inference(model_bundle["model"], input_ids, cfg) # Use the exact inference decode/cleanup logic for parity with inference.py raw_output_text = _decode_clean(tgt_tok, out[0].tolist()) if clean_output: output_text = _decode_with_cleanup( tgt_tok, out[0].tolist(), input_text.strip(), cfg["inference"] ) else: output_text = raw_output_text if not output_text: output_text = "(empty output)" record = { "timestamp": datetime.now().isoformat(timespec="seconds"), "experiment": model_bundle["experiment"], "checkpoint": model_bundle["ckpt_path"], "input_text": input_text, "raw_output_text": raw_output_text, "output_text": output_text, "temperature": float(temperature), "top_k": int(top_k), "repetition_penalty": float(repetition_penalty), "diversity_penalty": float(diversity_penalty), "num_steps": int(num_steps), "clean_output": bool(clean_output), } log_path = save_generation(model_bundle["experiment"], record) latency_ms = (time.perf_counter() - t0) * 1000.0 toks = [t for t in output_text.split() if t] uniq = len(set(toks)) / max(1, len(toks)) _mlflow_event( run_name="space_inference", params={ "experiment": model_bundle["experiment"], "checkpoint": model_bundle["ckpt_path"], "temperature": float(temperature), "top_k": int(top_k), "repetition_penalty": float(repetition_penalty), "diversity_penalty": float(diversity_penalty), "num_steps": int(num_steps), "clean_output": bool(clean_output), }, metrics={ "latency_ms": latency_ms, "input_char_len": len(input_text.strip()), "output_char_len": len(output_text), "output_token_len": len(toks), "output_unique_ratio": uniq, }, tags={"source": "hf_space"}, ) status = f"Inference done. Saved: `{log_path}`" return output_text, status, record def _resolve_analysis_script() -> Path | None: candidates = [ Path("analysis") / "run_analysis.py", Path("final_folder") / "analysis" / "run_analysis.py", Path("deploy_ready") / "space_repo" / "analysis" / "run_analysis.py", ] for p in candidates: if p.exists(): return p return None def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze", task5_samples=50): os.makedirs(output_dir, exist_ok=True) script = _resolve_analysis_script() if script is None: bundled = Path("analysis_outputs") if bundled.exists(): return 0, "Analysis runner not bundled; using packaged analysis_outputs.", True return 2, "Analysis runner missing and no bundled analysis_outputs found.", False # Space-safe Task4 fallback: if ablation models don't exist, bootstrap them # from currently selected checkpoint so Task4 can still execute end-to-end. if str(task) == "4" and phase == "analyze": for t in (4, 8, 16, 32, 64): t_dir = Path("ablation_results") / f"T{t}" t_dir.mkdir(parents=True, exist_ok=True) dst = t_dir / "best_model.pt" if not dst.exists(): try: os.symlink(os.path.abspath(ckpt_path), str(dst)) except Exception: import shutil shutil.copy2(ckpt_path, str(dst)) cmd = [ sys.executable, str(script), "--task", str(task), "--checkpoint", ckpt_path, "--output_dir", output_dir, ] if str(task) == "2" or str(task) == "all": cmd.extend(["--input", input_text]) if str(task) == "4": cmd.extend(["--phase", phase]) if str(task) == "5": cmd.extend(["--task5_samples", str(int(task5_samples))]) env = os.environ.copy() env.setdefault("HF_HOME", "/tmp/hf_home") env.setdefault("HF_DATASETS_CACHE", "/tmp/hf_datasets") env.setdefault("HF_HUB_CACHE", "/tmp/hf_hub") timeout_map = {"1": 120, "2": 180, "3": 240, "4": 300, "5": 240} timeout_s = int(os.environ.get("TASK_TIMEOUT_S", timeout_map.get(str(task), 180))) try: proc = subprocess.run(cmd, capture_output=True, text=True, env=env, timeout=timeout_s) log = f"$ {' '.join(cmd)}\n\n{proc.stdout}\n{proc.stderr}" return proc.returncode, log, False except subprocess.TimeoutExpired as e: out = e.stdout or "" err = e.stderr or "" log = f"$ {' '.join(cmd)}\n\n[timeout after {timeout_s}s]\n{out}\n{err}" return 124, log, False def _bundle_task_outputs(model_bundle, output_dir): src_dir = _guess_analysis_dir( model_bundle.get("experiment", ""), model_bundle.get("ckpt_path", ""), family=model_bundle.get("family", "cross_attention"), settings=model_bundle.get("hf_settings", {}), ) if not os.path.isdir(src_dir): return os.makedirs(output_dir, exist_ok=True) for name in os.listdir(src_dir): src = os.path.join(src_dir, name) dst = os.path.join(output_dir, name) if os.path.isfile(src): shutil.copy2(src, dst) def _live_input_summary(model_bundle, input_text: str) -> str: if not input_text.strip(): return "No input text provided." cfg = copy.deepcopy(model_bundle["cfg"]) src_tok = model_bundle["src_tok"] tgt_tok = model_bundle["tgt_tok"] device = torch.device(model_bundle["device"]) inp = torch.tensor([src_tok.encode(input_text.strip())[:cfg["model"]["max_seq_len"]]], dtype=torch.long, device=device) out = run_inference(model_bundle["model"], inp, cfg) pred = _decode_with_cleanup(tgt_tok, out[0].tolist(), input_text.strip(), cfg["inference"]) toks = pred.split() uniq = len(set(toks)) / max(1, len(toks)) return ( f"Live input: {input_text}\n" f"Prediction: {pred}\n" f"Length(tokens): {len(toks)}\n" f"Unique-token ratio: {uniq:.3f}" ) def _mini_tfidf_scores(text: str) -> dict: tokens = [t for t in text.split() if t.strip()] if not tokens: return {} corpus = [ "dharmo rakṣati rakṣitaḥ", "satyameva jayate", "ahiṃsā paramo dharmaḥ", "vasudhaiva kuṭumbakam", "yatra nāryastu pūjyante", text, ] docs = [set([t for t in d.split() if t.strip()]) for d in corpus] n = len(docs) scores = {} for tok in tokens: df = sum(1 for d in docs if tok in d) idf = (1.0 + (n + 1) / (1 + df)) scores[tok] = round(float(idf), 4) return scores def _run_single_prediction(model_bundle, text: str, cfg_override: dict | None = None) -> str: cfg = copy.deepcopy(model_bundle["cfg"]) if cfg_override: for k, v in cfg_override.items(): cfg["inference"][k] = v src_tok = model_bundle["src_tok"] tgt_tok = model_bundle["tgt_tok"] device = torch.device(model_bundle["device"]) input_ids = torch.tensor( [src_tok.encode(text.strip())[:cfg["model"]["max_seq_len"]]], dtype=torch.long, device=device, ) out = run_inference(model_bundle["model"], input_ids, cfg) return _decode_with_cleanup(tgt_tok, out[0].tolist(), text.strip(), cfg["inference"]) def _live_task_analysis(model_bundle, task: str, input_text: str, task5_cfg: dict | None = None) -> str: text = input_text.strip() if not text: return "Live analysis skipped: empty input." pred = _run_single_prediction(model_bundle, text) toks = [t for t in pred.split() if t] uniq = len(set(toks)) / max(1, len(toks)) if str(task) == "1": t0 = datetime.now() _ = _run_single_prediction(model_bundle, text, {"num_steps": 16}) t1 = datetime.now() _ = _run_single_prediction(model_bundle, text, {"num_steps": 64}) t2 = datetime.now() fast_ms = (t1 - t0).total_seconds() * 1000 full_ms = (t2 - t1).total_seconds() * 1000 return ( f"[Live Task1]\n" f"Input: {text}\nPrediction: {pred}\n" f"Token-length={len(toks)} unique-ratio={uniq:.3f}\n" f"Latency proxy: 16-step={fast_ms:.1f}ms, 64-step={full_ms:.1f}ms" ) if str(task) == "2": # Live diffusion proxy: run same input with multiple step counts and # show semantic drift to final output while task is running. base_steps = int(model_bundle["cfg"]["inference"].get("num_steps", 64)) step_grid = sorted(set([max(1, base_steps), max(1, base_steps // 2), max(1, base_steps // 4), 1]), reverse=True) traj = [] final_out = None for s in step_grid: out_s = _run_single_prediction(model_bundle, text, {"num_steps": int(s)}) if s == 1: final_out = out_s traj.append((s, out_s)) if final_out is None and traj: final_out = traj[-1][1] drift_rows = [] for s, out_s in traj: d = _compute_cer(out_s, final_out or out_s) drift_rows.append((s, round(d, 4), out_s[:56])) tfidf = _mini_tfidf_scores(text) top = sorted(tfidf.items(), key=lambda kv: kv[1], reverse=True)[:5] traj_txt = "\n".join([f"steps={s:>3d} drift_to_final={d:.4f} out={o}" for s, d, o in drift_rows]) return ( f"[Live Task2]\n" f"Input: {text}\nPrediction: {pred}\n" f"Token-length={len(toks)} unique-ratio={uniq:.3f}\n" f"TF-IDF(top): {top}\n" f"Diffusion trajectory (live):\n{traj_txt}" ) if str(task) == "3": tfidf = _mini_tfidf_scores(text) tf_mean = sum(tfidf.values()) / max(1, len(tfidf)) return ( f"[Live Task3]\n" f"Input: {text}\nPrediction: {pred}\n" f"Token-length={len(toks)} unique-ratio={uniq:.3f}\n" f"Concept proxy: mean TF-IDF={tf_mean:.3f}" ) if str(task) == "5": ref = _iast_to_deva(text) scales = (task5_cfg or {}).get("scales", [0.0, 0.5, 1.0, 1.5, 2.0]) rows = [] for s in scales: cfg_map = { "repetition_penalty": 1.1 + 0.15 * s, "diversity_penalty": min(1.0, 0.10 * s), } out = _run_single_prediction(model_bundle, text, cfg_map) cer = _compute_cer(out, ref) rows.append((s, round(cer, 4), out[:48])) return "[Live Task5]\n" + "\n".join([f"λ={r[0]:.1f} CER={r[1]:.4f} out={r[2]}" for r in rows]) return _live_input_summary(model_bundle, text) def _run_quick_task(task, model_bundle, input_text, task5_cfg): log = ( f"[Quick Mode] Task {task}\n" f"Heavy analysis runner skipped for speed.\n\n" f"{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}" ) return 0, log, False def _bg_worker( job_id: str, model_bundle, output_dir: str, input_text: str, task4_phase: str, task5_cfg: dict, quick_mode: bool, ): tasks = ["1", "2", "3", "4", "5"] failures = 0 logs = [] run_start = time.perf_counter() _BG_JOBS[job_id].update({"state": "running", "progress": 0, "failures": 0, "updated": datetime.now().isoformat()}) for idx, task in enumerate(tasks, start=1): _BG_JOBS[job_id]["task_states"][task] = "running" _BG_JOBS[job_id].update( { "state": f"running task {task}", "progress": int((idx - 1) * 100 / len(tasks)), "updated": datetime.now().isoformat(), } ) try: if quick_mode: code, log, used_bundled = _run_quick_task(task, model_bundle, input_text, task5_cfg) else: code, log, used_bundled = _run_analysis_cmd( task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase, task5_cfg.get("samples", 50), ) logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}") if code != 0: failures += 1 try: logs.append(f"\n[Live fallback]\n{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}\n") _BG_JOBS[job_id]["task_states"][task] = "done(live-fast)" except Exception as live_e: _BG_JOBS[job_id]["task_states"][task] = "failed" logs.append(f"\n[Live fallback failed]\n{live_e}\n") elif used_bundled: _BG_JOBS[job_id]["task_states"][task] = "done(bundled)" logs.append(f"\n[Live bundled summary]\n{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}\n") else: _BG_JOBS[job_id]["task_states"][task] = "done" except Exception as e: failures += 1 _BG_JOBS[job_id]["task_states"][task] = "failed" logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n[worker exception]\n{e}\n") code, used_bundled = 1, False _BG_JOBS[job_id].update( { "log": "".join(logs), "failures": failures, "progress": int(idx * 100 / len(tasks)), "updated": datetime.now().isoformat(), } ) _mlflow_event( run_name=f"space_bg_task_{task}", params={ "job_id": job_id, "task": task, "task4_phase": str(task4_phase), "experiment": model_bundle.get("experiment", ""), }, metrics={ "exit_code": float(code), "used_bundled": 1.0 if used_bundled else 0.0, "failures_so_far": float(failures), "progress_pct": float(_BG_JOBS[job_id]["progress"]), }, tags={"source": "hf_space", "mode": "background"}, ) if failures: _bundle_task_outputs(model_bundle, output_dir) _BG_JOBS[job_id].update( { "state": "done", "done": True, "progress": 100, "log": "".join(logs), "failures": failures, "updated": datetime.now().isoformat(), } ) _mlflow_event( run_name="space_bg_run", params={ "job_id": job_id, "task4_phase": str(task4_phase), "experiment": model_bundle.get("experiment", ""), "output_dir": str(output_dir), }, metrics={ "failures": float(failures), "elapsed_s": (time.perf_counter() - run_start), }, tags={"source": "hf_space", "mode": "background_summary"}, ) def start_run_all_background(model_bundle, output_dir, input_text, task4_phase, task5_cfg, quick_mode): if not model_bundle: raise gr.Error("Load a model first.") os.makedirs(output_dir, exist_ok=True) job_id = uuid.uuid4().hex[:10] _BG_JOBS[job_id] = { "state": "queued", "progress": 0, "log": "", "failures": 0, "done": False, "output_dir": output_dir, "created": datetime.now().isoformat(), "updated": datetime.now().isoformat(), "task_states": {k: "pending" for k in ["1", "2", "3", "4", "5"]}, } th = threading.Thread( target=_bg_worker, args=(job_id, model_bundle, output_dir, input_text, task4_phase, task5_cfg, bool(quick_mode)), daemon=True, ) th.start() flow = _build_flow_markdown(model_loaded=True, inference_ready=True, task_states=_BG_JOBS[job_id]["task_states"]) mode = "Quick" if quick_mode else "Full" return ( f"Background run started ({mode} Mode). Job ID: {job_id}", f"Job {job_id} queued...", job_id, _BG_JOBS[job_id]["task_states"], flow, ) def poll_run_all_background(job_id, output_dir): if not job_id or job_id not in _BG_JOBS: msg = "Background job idle. You can run a single task or start Run All 5 in background." empty = refresh_task_outputs(output_dir) flow = _build_flow_markdown(model_loaded=False, inference_ready=False, task_states={}) return msg, msg, {}, flow, *empty j = _BG_JOBS[job_id] status = ( f"Job {job_id} | state={j['state']} | progress={j['progress']}% | " f"failures={j['failures']} | updated={j['updated']}" ) outputs = refresh_task_outputs(output_dir) flow = _build_flow_markdown(model_loaded=True, inference_ready=True, task_states=j.get("task_states", {})) return status, j.get("log", ""), j.get("task_states", {}), flow, *outputs def run_single_task_and_refresh(model_bundle, task, output_dir, input_text, task4_phase, task5_cfg, quick_mode): status, log, task_states, flow = run_single_task( model_bundle, task, output_dir, input_text, task4_phase, task5_cfg, quick_mode ) out = refresh_task_outputs(output_dir) return status, log, task_states, flow, *out def run_single_task(model_bundle, task, output_dir, input_text, task4_phase, task5_cfg, quick_mode): if not model_bundle: raise gr.Error("Load a model first.") t0 = time.perf_counter() if quick_mode: code, log, used_bundled = _run_quick_task(task, model_bundle, input_text, task5_cfg) else: code, log, used_bundled = _run_analysis_cmd( task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase, task5_cfg.get("samples", 50) ) task_states = {k: "pending" for k in ["1", "2", "3", "4", "5"]} task_states[str(task)] = "running" elapsed = (time.perf_counter() - t0) * 1000.0 if code != 0: _bundle_task_outputs(model_bundle, output_dir) try: log = f"{log}\n\n--- Live task analysis ---\n{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}" status = f"Task {task} fallback mode: bundled reports + live input analysis." task_states[str(task)] = "done(live-fast)" except Exception as e: log = f"{log}\n\n--- Live task analysis failed ---\n{e}" status = f"Task {task} failed (and live fallback failed)." task_states[str(task)] = "failed" else: if used_bundled: _bundle_task_outputs(model_bundle, output_dir) log = f"{log}\n\n--- Live task analysis ---\n{_live_task_analysis(model_bundle, task, input_text, task5_cfg)}" status = f"Task {task} loaded from bundled analysis outputs + live analysis." task_states[str(task)] = "done(bundled)" else: status = f"Task {task} completed (exit={code})." task_states[str(task)] = "done" _mlflow_event( run_name=f"space_task_{task}", params={ "task": str(task), "task4_phase": str(task4_phase), "output_dir": str(output_dir), "experiment": model_bundle.get("experiment", ""), }, metrics={ "exit_code": float(code), "elapsed_ms": elapsed, "used_bundled": 1.0 if used_bundled else 0.0, }, tags={"source": "hf_space", "mode": "single_task"}, ) flow = _build_flow_markdown(model_loaded=True, inference_ready=True, task_states=task_states) return status, log, task_states, flow def _read_text(path): if not os.path.exists(path): return "Not found." with open(path, "r", encoding="utf-8", errors="ignore") as f: return f.read() def _img_or_none(path): return path if os.path.exists(path) else None def refresh_task_outputs(output_dir): task1_txt = _read_text(os.path.join(output_dir, "task1_kv_cache.txt")) task2_txt = _read_text(os.path.join(output_dir, "task2_report.txt")) task3_txt = _read_text(os.path.join(output_dir, "task3_report.txt")) task5_txt = _read_text(os.path.join(output_dir, "task5_report.txt")) task2_drift = _img_or_none(os.path.join(output_dir, "task2_semantic_drift.png")) task2_attn = _img_or_none(os.path.join(output_dir, "task2_attn_t0.png")) task2_evolution = _img_or_none(os.path.join(output_dir, "task2_attn_evolution.png")) # Show farthest diffusion step snapshot if available (t=max). task2_tmax = None try: cands = [] for name in os.listdir(output_dir): if name.startswith("task2_attn_t") and name.endswith(".png"): step = name.replace("task2_attn_t", "").replace(".png", "") if step.isdigit(): cands.append((int(step), os.path.join(output_dir, name))) if cands: cands.sort(key=lambda x: x[0], reverse=True) task2_tmax = cands[0][1] except Exception: task2_tmax = None task3_space = _img_or_none(os.path.join(output_dir, "task3_concept_space.png")) task4_plot = _img_or_none(os.path.join(output_dir, "task4_ablation_3d.png")) if task4_plot is None: task4_plot = _img_or_none(os.path.join(output_dir, "task4_3d.png")) return ( task1_txt, task2_txt, task2_drift, task2_attn, task2_tmax, task2_evolution, task3_txt, task3_space, task5_txt, task4_plot ) def _safe_refresh_task_outputs(output_dir): try: return refresh_task_outputs(output_dir) except Exception as e: err = f"Refresh error: {e}" return (err, err, None, None, None, None, err, None, err, None) def _safe_start_run_all_background( model_bundle, output_dir, input_text, task4_phase, current_job_id, lambda_min, lambda_max, lambda_step, task5_samples, quick_mode ): try: cfg = _task5_cfg(lambda_min, lambda_max, lambda_step, task5_samples) status, log, job_id, task_states, flow = start_run_all_background( model_bundle, output_dir, input_text, task4_phase, cfg, quick_mode ) return status, log, job_id, task_states, flow except Exception as e: err = f"Background start failed: {e}" return err, err, current_job_id, {}, _build_flow_markdown(model_loaded=bool(model_bundle), inference_ready=False, task_states={}) def _safe_poll_run_all_background(job_id, output_dir): try: return poll_run_all_background(job_id, output_dir) except Exception as e: err = f"Track error: {e}" out = _safe_refresh_task_outputs(output_dir) return err, err, {}, _build_flow_markdown(model_loaded=False, inference_ready=False, task_states={}), *out def _safe_run_single_task_and_refresh( model_bundle, task, output_dir, input_text, task4_phase, lambda_min, lambda_max, lambda_step, task5_samples, quick_mode ): try: cfg = _task5_cfg(lambda_min, lambda_max, lambda_step, task5_samples) return run_single_task_and_refresh(model_bundle, task, output_dir, input_text, task4_phase, cfg, quick_mode) except Exception as e: err = f"Task {task} failed: {e}" out = _safe_refresh_task_outputs(output_dir) return err, err, {}, _build_flow_markdown(model_loaded=bool(model_bundle), inference_ready=False, task_states={}), *out def _generate_with_flow( model_bundle, input_text, temperature, top_k, repetition_penalty, diversity_penalty, num_steps, clean_output, ): out_text, status, meta = generate_from_ui( model_bundle, input_text, temperature, top_k, repetition_penalty, diversity_penalty, num_steps, clean_output, ) flow = _build_flow_markdown(model_loaded=True, inference_ready=True, task_states={}) return out_text, status, meta, flow def load_selected_model_with_outputs(checkpoint_label): bundle, status, info, steps, out_dir = load_selected_model(checkpoint_label) outputs = _safe_refresh_task_outputs(out_dir) flow = _build_flow_markdown(model_loaded=True, inference_ready=False, task_states={}) return bundle, status, info, steps, out_dir, flow, *outputs def auto_load_default_with_outputs(): choices = list(checkpoint_map().keys()) if not choices: empty = _safe_refresh_task_outputs(DEFAULT_ANALYSIS_OUT) return None, "No checkpoints found.", {}, 64, DEFAULT_ANALYSIS_OUT, _build_flow_markdown(model_loaded=False, inference_ready=False, task_states={}), *empty return load_selected_model_with_outputs(default_checkpoint_label()) def safe_load_selected_model_with_outputs(checkpoint_label): try: return load_selected_model_with_outputs(checkpoint_label) except Exception as e: outputs = _safe_refresh_task_outputs(DEFAULT_ANALYSIS_OUT) return ( None, f"Load failed: {e}", {}, 64, DEFAULT_ANALYSIS_OUT, _build_flow_markdown(model_loaded=False, inference_ready=False, task_states={}), *outputs, ) CUSTOM_CSS = """ :root { --bg1: #f5fbff; --bg2: #f2f7ef; --card: #ffffff; --line: #d9e6f2; --ink: #163048; } .gradio-container { background: linear-gradient(130deg, var(--bg1), var(--bg2)); color: var(--ink); } #hero { background: radial-gradient(110% 130% at 0% 0%, #d7ebff 0%, #ecf6ff 55%, #f8fbff 100%); border: 1px solid #cfe0f1; border-radius: 16px; padding: 18px 20px; } .panel { background: var(--card); border: 1px solid var(--line); border-radius: 14px; } """ with gr.Blocks(title="Sanskrit Diffusion Model", css=CUSTOM_CSS) as demo: model_state = gr.State(None) bg_job_state = gr.State("") gr.Markdown( """
Select any trained model, run all 5 analysis tasks or individual tasks, then test inference with user-controlled parameters.