import copy import json import os import subprocess import sys from datetime import datetime import gradio as gr import torch from huggingface_hub import hf_hub_download, list_repo_files from gradio.blocks import Blocks as _GradioBlocks from config import CONFIG from inference import _build_tokenizers, _resolve_device, load_model, run_inference RESULTS_DIR = "generated_results" DEFAULT_ANALYSIS_OUT = "analysis/outputs" os.makedirs(RESULTS_DIR, exist_ok=True) MODEL_CACHE = {} # HF Spaces currently installs gradio[oauth]==5.0.0. In that stack, API schema # generation can crash with: # TypeError: argument of type 'bool' is not iterable # Guard it so UI still serves even if API metadata generation fails. _ORIG_GET_API_INFO = _GradioBlocks.get_api_info def _safe_get_api_info(self): try: return _ORIG_GET_API_INFO(self) except TypeError as e: if "bool' is not iterable" in str(e): return {"named_endpoints": {}, "unnamed_endpoints": {}} raise _GradioBlocks.get_api_info = _safe_get_api_info def discover_checkpoints(): found = [] for root in ("ablation_results", "results7", "results"): if not os.path.isdir(root): continue for entry in sorted(os.listdir(root)): ckpt = os.path.join(root, entry, "best_model.pt") if not os.path.exists(ckpt): continue found.append( { "label": f"{entry} [{root}]", "path": ckpt, "experiment": entry, "root": root, } ) repo = os.getenv("HF_CHECKPOINT_REPO", "").strip() if repo: branch = os.getenv("HF_CHECKPOINT_REVISION", "main").strip() or "main" try: for fname in list_repo_files(repo_id=repo, repo_type="model", revision=branch): if not fname.endswith("/best_model.pt") and fname != "best_model.pt": continue local_path = hf_hub_download(repo_id=repo, filename=fname, revision=branch, repo_type="model") parent = os.path.basename(os.path.dirname(fname)) if "/" in fname else "remote" root = os.path.dirname(fname).split("/")[0] if "/" in fname else "remote" found.append( { "label": f"{parent} [hf:{repo}]", "path": local_path, "experiment": parent, "root": root, } ) except Exception as e: print(f"[WARN] Could not discover remote checkpoints from {repo}: {e}") return found def checkpoint_map(): return {item["label"]: item for item in discover_checkpoints()} def default_checkpoint_label(): cps = discover_checkpoints() if not cps: return None for item in cps: if item["path"].endswith("ablation_results/T4/best_model.pt"): return item["label"] return cps[0]["label"] def infer_model_type(experiment_name: str, root: str = "") -> str: if root == "ablation_results": return "d3pm_cross_attention" if experiment_name.startswith("d3pm_cross_attention"): return "d3pm_cross_attention" if experiment_name.startswith("d3pm_encoder_decoder"): return "d3pm_encoder_decoder" if experiment_name.startswith("baseline_cross_attention"): return "baseline_cross_attention" if experiment_name.startswith("baseline_encoder_decoder"): return "baseline_encoder_decoder" return CONFIG["model_type"] def infer_include_negative(experiment_name: str, root: str = "") -> bool: if root == "ablation_results": return False if "_neg_True" in experiment_name: return True if "_neg_False" in experiment_name: return False return CONFIG["data"]["include_negative_examples"] def build_runtime_cfg(ckpt_path: str): experiment = os.path.basename(os.path.dirname(ckpt_path)) or "remote" root = os.path.basename(os.path.dirname(os.path.dirname(ckpt_path))) or "remote" cfg = copy.deepcopy(CONFIG) cfg["model_type"] = infer_model_type(experiment, root=root) cfg["data"]["include_negative_examples"] = infer_include_negative(experiment, root=root) if root == "ablation_results" and experiment.startswith("T") and experiment[1:].isdigit(): t_val = int(experiment[1:]) cfg["model"]["diffusion_steps"] = t_val cfg["inference"]["num_steps"] = t_val device = _resolve_device(cfg) return cfg, device, experiment def load_selected_model(checkpoint_label): mapping = checkpoint_map() if checkpoint_label not in mapping: raise gr.Error("Selected checkpoint not found. Click refresh.") ckpt_path = mapping[checkpoint_label]["path"] cfg, device, experiment = build_runtime_cfg(ckpt_path) model, cfg = load_model(ckpt_path, cfg, device) src_tok, tgt_tok = _build_tokenizers(cfg) bundle = { "ckpt_path": ckpt_path, "experiment": experiment, "device": str(device), "cfg": cfg, "model": model, "src_tok": src_tok, "tgt_tok": tgt_tok, } MODEL_CACHE[checkpoint_label] = bundle model_info = { "checkpoint": ckpt_path, "experiment": experiment, "model_type": cfg["model_type"], "include_negatives": cfg["data"]["include_negative_examples"], "device": str(device), "max_seq_len": cfg["model"]["max_seq_len"], "diffusion_steps": cfg["model"]["diffusion_steps"], "inference_steps": cfg["inference"]["num_steps"], "d_model": cfg["model"]["d_model"], "n_layers": cfg["model"]["n_layers"], "n_heads": cfg["model"]["n_heads"], } status = f"Loaded `{experiment}` on `{device}` (`{cfg['model_type']}`)" suggested_out = os.path.join("analysis", "outputs_ui", experiment) return checkpoint_label, status, json.dumps(model_info, ensure_ascii=False, indent=2), cfg["inference"]["num_steps"], suggested_out def _get_bundle(model_key: str): if not model_key: raise gr.Error("Load a model first.") if model_key not in MODEL_CACHE: mapping = checkpoint_map() if model_key not in mapping: raise gr.Error("Selected checkpoint is no longer available. Refresh and load again.") # Lazy reload if Space process restarted. load_selected_model(model_key) return MODEL_CACHE[model_key] def apply_preset(preset_name): presets = { "Manual": (0.70, 40, 1.20, 0.0), "Literal": (0.60, 20, 1.25, 0.0), "Balanced": (0.70, 40, 1.20, 0.0), "Creative": (0.90, 80, 1.05, 0.2), } return presets.get(preset_name, presets["Balanced"]) def clean_generated_text(text: str, max_consecutive: int = 2) -> str: text = " ".join(text.split()) if not text: return text tokens = text.split() cleaned = [] prev = None run = 0 for tok in tokens: if tok == prev: run += 1 else: prev = tok run = 1 if run <= max_consecutive: cleaned.append(tok) out = " ".join(cleaned).replace(" ।", "।").replace(" ॥", "॥") return " ".join(out.split()) def save_generation(experiment, record): ts = datetime.now().strftime("%Y%m%d") path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json") existing = [] if os.path.exists(path): with open(path, "r", encoding="utf-8") as f: existing = json.load(f) existing.append(record) with open(path, "w", encoding="utf-8") as f: json.dump(existing, f, ensure_ascii=False, indent=2) return path def generate_from_ui( model_key, input_text, temperature, top_k, repetition_penalty, diversity_penalty, num_steps, clean_output, ): model_bundle = _get_bundle(model_key) if not input_text.strip(): raise gr.Error("Enter input text first.") cfg = copy.deepcopy(model_bundle["cfg"]) cfg["inference"]["temperature"] = float(temperature) cfg["inference"]["top_k"] = int(top_k) cfg["inference"]["repetition_penalty"] = float(repetition_penalty) cfg["inference"]["diversity_penalty"] = float(diversity_penalty) cfg["inference"]["num_steps"] = int(num_steps) src_tok = model_bundle["src_tok"] tgt_tok = model_bundle["tgt_tok"] device = torch.device(model_bundle["device"]) input_ids = torch.tensor([src_tok.encode(input_text.strip())], dtype=torch.long, device=device) out = run_inference(model_bundle["model"], input_ids, cfg) # Align decode with validation style: strip only special ids. pad_id = 1 mask_id = cfg["diffusion"]["mask_token_id"] decoded_ids = [x for x in out[0].tolist() if x not in (pad_id, mask_id)] raw_output_text = tgt_tok.decode(decoded_ids).strip() output_text = clean_generated_text(raw_output_text) if clean_output else raw_output_text if not output_text: output_text = "(empty output)" record = { "timestamp": datetime.now().isoformat(timespec="seconds"), "experiment": model_bundle["experiment"], "checkpoint": model_bundle["ckpt_path"], "input_text": input_text, "raw_output_text": raw_output_text, "output_text": output_text, "temperature": float(temperature), "top_k": int(top_k), "repetition_penalty": float(repetition_penalty), "diversity_penalty": float(diversity_penalty), "num_steps": int(num_steps), "clean_output": bool(clean_output), } log_path = save_generation(model_bundle["experiment"], record) status = f"Inference done. Saved: `{log_path}`" return output_text, status, json.dumps(record, ensure_ascii=False, indent=2) def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze"): os.makedirs(output_dir, exist_ok=True) cmd = [ sys.executable, "analysis/run_analysis.py", "--task", str(task), "--checkpoint", ckpt_path, "--output_dir", output_dir, ] if str(task) == "2" or str(task) == "all": cmd.extend(["--input", input_text]) if str(task) == "4": cmd.extend(["--phase", phase]) env = os.environ.copy() env.setdefault("HF_HOME", "/tmp/hf_home") env.setdefault("HF_DATASETS_CACHE", "/tmp/hf_datasets") env.setdefault("HF_HUB_CACHE", "/tmp/hf_hub") proc = subprocess.run(cmd, capture_output=True, text=True, env=env) log = f"$ {' '.join(cmd)}\n\n{proc.stdout}\n{proc.stderr}" return proc.returncode, log def run_single_task(model_key, task, output_dir, input_text, task4_phase): model_bundle = _get_bundle(model_key) code, log = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase) status = f"Task {task} {'completed' if code == 0 else 'failed'} (exit={code})." return status, log def run_all_tasks(model_key, output_dir, input_text, task4_phase): model_bundle = _get_bundle(model_key) logs = [] failures = 0 for task in ["1", "2", "3", "4", "5"]: code, log = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase) logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}") if code != 0: failures += 1 status = f"Run-all finished with {failures} failed task(s)." if failures else "All 5 tasks completed." return status, "".join(logs) def _read_text(path): if not os.path.exists(path): return "Not found." with open(path, "r", encoding="utf-8", errors="ignore") as f: return f.read() def _img_or_none(path): return path if os.path.exists(path) else None def refresh_task_outputs(output_dir): task1_txt = _read_text(os.path.join(output_dir, "task1_kv_cache.txt")) task2_txt = _read_text(os.path.join(output_dir, "task2_report.txt")) task3_txt = _read_text(os.path.join(output_dir, "task3_report.txt")) task5_txt = _read_text(os.path.join(output_dir, "task5_report.txt")) task2_drift = _img_or_none(os.path.join(output_dir, "task2_semantic_drift.png")) task2_attn = _img_or_none(os.path.join(output_dir, "task2_attn_t0.png")) task3_space = _img_or_none(os.path.join(output_dir, "task3_concept_space.png")) task4_plot = _img_or_none(os.path.join(output_dir, "task4_ablation_3d.png")) return task1_txt, task2_txt, task2_drift, task2_attn, task3_txt, task3_space, task5_txt, task4_plot CUSTOM_CSS = """ :root { --bg1: #f5fbff; --bg2: #f2f7ef; --card: #ffffff; --line: #d9e6f2; --ink: #163048; } .gradio-container { background: linear-gradient(130deg, var(--bg1), var(--bg2)); color: var(--ink); } #hero { background: radial-gradient(110% 130% at 0% 0%, #d7ebff 0%, #ecf6ff 55%, #f8fbff 100%); border: 1px solid #cfe0f1; border-radius: 16px; padding: 18px 20px; } .panel { background: var(--card); border: 1px solid var(--line); border-radius: 14px; } """ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo: model_state = gr.State("") gr.Markdown( """
Select any trained model, run all 5 analysis tasks or individual tasks, then test inference with user-controlled parameters.