# DevaFlow-space / app.py
# Author: bhsinghgrid
# Hotfix: guard Gradio 5.0 API schema bug
# Commit 9124d18 (verified)
import copy
import json
import os
import subprocess
import sys
from datetime import datetime
import gradio as gr
import torch
from huggingface_hub import hf_hub_download, list_repo_files
from gradio.blocks import Blocks as _GradioBlocks
from config import CONFIG
from inference import _build_tokenizers, _resolve_device, load_model, run_inference
# Directory where the UI appends per-experiment generation logs (JSON files).
RESULTS_DIR = "generated_results"
# Default output directory offered in the UI for the analysis task runner.
DEFAULT_ANALYSIS_OUT = "analysis/outputs"
os.makedirs(RESULTS_DIR, exist_ok=True)
# Checkpoint label -> loaded bundle (model, tokenizers, cfg); reused across UI calls.
MODEL_CACHE = {}
# HF Spaces currently installs gradio[oauth]==5.0.0. In that stack, API schema
# generation can crash with:
#   TypeError: argument of type 'bool' is not iterable
# Wrap get_api_info so the UI still serves even if API metadata generation fails.
_ORIG_GET_API_INFO = _GradioBlocks.get_api_info
def _safe_get_api_info(self):
    """Delegate to Gradio's get_api_info, degrading to an empty schema on the known 5.0 bug."""
    try:
        return _ORIG_GET_API_INFO(self)
    except TypeError as exc:
        # Swallow only the one known failure signature; anything else is a real bug.
        if "bool' is not iterable" not in str(exc):
            raise
        return {"named_endpoints": {}, "unnamed_endpoints": {}}
_GradioBlocks.get_api_info = _safe_get_api_info
def discover_checkpoints():
    """Enumerate available best_model.pt checkpoints.

    Scans the known local result directories first, then (when
    HF_CHECKPOINT_REPO is set) the remote HF model repo. Returns a list of
    dicts with keys: label, path, experiment, root.
    """
    catalog = []
    for root in ("ablation_results", "results7", "results"):
        if not os.path.isdir(root):
            continue
        for entry in sorted(os.listdir(root)):
            ckpt = os.path.join(root, entry, "best_model.pt")
            if os.path.exists(ckpt):
                catalog.append(
                    {
                        "label": f"{entry} [{root}]",
                        "path": ckpt,
                        "experiment": entry,
                        "root": root,
                    }
                )
    repo = os.getenv("HF_CHECKPOINT_REPO", "").strip()
    if not repo:
        return catalog
    branch = os.getenv("HF_CHECKPOINT_REVISION", "main").strip() or "main"
    # Remote discovery is best-effort: network/auth failures must not break the UI.
    try:
        for fname in list_repo_files(repo_id=repo, repo_type="model", revision=branch):
            nested = fname.endswith("/best_model.pt")
            if not (nested or fname == "best_model.pt"):
                continue
            local_path = hf_hub_download(repo_id=repo, filename=fname, revision=branch, repo_type="model")
            if "/" in fname:
                parent = os.path.basename(os.path.dirname(fname))
                root = os.path.dirname(fname).split("/")[0]
            else:
                parent = root = "remote"
            catalog.append(
                {
                    "label": f"{parent} [hf:{repo}]",
                    "path": local_path,
                    "experiment": parent,
                    "root": root,
                }
            )
    except Exception as e:
        print(f"[WARN] Could not discover remote checkpoints from {repo}: {e}")
    return catalog
def checkpoint_map():
    """Return a label -> checkpoint-descriptor mapping from a fresh discovery pass."""
    mapping = {}
    for descriptor in discover_checkpoints():
        mapping[descriptor["label"]] = descriptor
    return mapping
def default_checkpoint_label():
    """Pick the dropdown's initial label: the T4 ablation checkpoint if present, else the first found, else None."""
    checkpoints = discover_checkpoints()
    for item in checkpoints:
        if item["path"].endswith("ablation_results/T4/best_model.pt"):
            return item["label"]
    return checkpoints[0]["label"] if checkpoints else None
def infer_model_type(experiment_name: str, root: str = "") -> str:
    """Deduce the architecture name from a checkpoint's directory names.

    Ablation runs are always cross-attention D3PM; otherwise the experiment
    folder is expected to start with the model-type string itself. Falls back
    to the global CONFIG default when nothing matches.
    """
    if root == "ablation_results":
        return "d3pm_cross_attention"
    known_types = (
        "d3pm_cross_attention",
        "d3pm_encoder_decoder",
        "baseline_cross_attention",
        "baseline_encoder_decoder",
    )
    for model_type in known_types:
        if experiment_name.startswith(model_type):
            return model_type
    return CONFIG["model_type"]
def infer_include_negative(experiment_name: str, root: str = "") -> bool:
    """Deduce whether a run was trained with negative examples from its folder name.

    Ablation runs never used negatives; other runs encode the flag as a
    `_neg_True` / `_neg_False` marker. Falls back to the CONFIG default.
    """
    if root == "ablation_results":
        return False
    for marker, flag in (("_neg_True", True), ("_neg_False", False)):
        if marker in experiment_name:
            return flag
    return CONFIG["data"]["include_negative_examples"]
def build_runtime_cfg(ckpt_path: str):
    """Build a per-checkpoint config from the checkpoint's path.

    Returns (cfg, device, experiment_name). The experiment/root directory
    names drive model-type and negative-example inference; ablation folders
    named T<steps> additionally pin the diffusion/inference step counts.
    """
    parent_dir = os.path.dirname(ckpt_path)
    experiment = os.path.basename(parent_dir) or "remote"
    root = os.path.basename(os.path.dirname(parent_dir)) or "remote"
    cfg = copy.deepcopy(CONFIG)
    cfg["model_type"] = infer_model_type(experiment, root=root)
    cfg["data"]["include_negative_examples"] = infer_include_negative(experiment, root=root)
    if root == "ablation_results" and experiment.startswith("T") and experiment[1:].isdigit():
        steps = int(experiment[1:])
        cfg["model"]["diffusion_steps"] = steps
        cfg["inference"]["num_steps"] = steps
    return cfg, _resolve_device(cfg), experiment
def load_selected_model(checkpoint_label):
    """Load the checkpoint behind `checkpoint_label`, cache it, and return UI outputs.

    Returns (label, status_markdown, model_info_json, num_steps, suggested_out_dir)
    matching the load button's output components. Raises gr.Error when the label
    is stale.
    """
    available = checkpoint_map()
    if checkpoint_label not in available:
        raise gr.Error("Selected checkpoint not found. Click refresh.")
    ckpt_path = available[checkpoint_label]["path"]
    cfg, device, experiment = build_runtime_cfg(ckpt_path)
    model, cfg = load_model(ckpt_path, cfg, device)
    src_tok, tgt_tok = _build_tokenizers(cfg)
    # Cache everything inference needs so later calls skip the load entirely.
    MODEL_CACHE[checkpoint_label] = {
        "ckpt_path": ckpt_path,
        "experiment": experiment,
        "device": str(device),
        "cfg": cfg,
        "model": model,
        "src_tok": src_tok,
        "tgt_tok": tgt_tok,
    }
    model_info = {
        "checkpoint": ckpt_path,
        "experiment": experiment,
        "model_type": cfg["model_type"],
        "include_negatives": cfg["data"]["include_negative_examples"],
        "device": str(device),
        "max_seq_len": cfg["model"]["max_seq_len"],
        "diffusion_steps": cfg["model"]["diffusion_steps"],
        "inference_steps": cfg["inference"]["num_steps"],
        "d_model": cfg["model"]["d_model"],
        "n_layers": cfg["model"]["n_layers"],
        "n_heads": cfg["model"]["n_heads"],
    }
    status = f"Loaded `{experiment}` on `{device}` (`{cfg['model_type']}`)"
    suggested_out = os.path.join("analysis", "outputs_ui", experiment)
    return (
        checkpoint_label,
        status,
        json.dumps(model_info, ensure_ascii=False, indent=2),
        cfg["inference"]["num_steps"],
        suggested_out,
    )
def _get_bundle(model_key: str):
    """Return the cached model bundle for `model_key`, reloading lazily on a cold cache."""
    if not model_key:
        raise gr.Error("Load a model first.")
    if model_key in MODEL_CACHE:
        return MODEL_CACHE[model_key]
    if model_key not in checkpoint_map():
        raise gr.Error("Selected checkpoint is no longer available. Refresh and load again.")
    # The Space process may have restarted since the user loaded; rebuild the bundle.
    load_selected_model(model_key)
    return MODEL_CACHE[model_key]
def apply_preset(preset_name):
    """Map a preset name to (temperature, top_k, repetition_penalty, diversity_penalty)."""
    balanced = (0.70, 40, 1.20, 0.0)
    presets = {
        "Manual": (0.70, 40, 1.20, 0.0),
        "Literal": (0.60, 20, 1.25, 0.0),
        "Balanced": balanced,
        "Creative": (0.90, 80, 1.05, 0.2),
    }
    # Unknown names fall back to the Balanced defaults.
    return presets.get(preset_name, balanced)
def clean_generated_text(text: str, max_consecutive: int = 2) -> str:
    """Normalize generated Devanagari text.

    Collapses whitespace, caps identical-token repeats at `max_consecutive`,
    and reattaches danda punctuation (। ॥) to the preceding word.
    """
    text = " ".join(text.split())
    if not text:
        return text
    kept = []
    last_tok = None
    streak = 0
    for tok in text.split():
        streak = streak + 1 if tok == last_tok else 1
        last_tok = tok
        if streak <= max_consecutive:
            kept.append(tok)
    rejoined = " ".join(kept).replace(" ।", "।").replace(" ॥", "॥")
    return " ".join(rejoined.split())
def save_generation(experiment, record):
    """Append `record` to today's JSON log for `experiment` and return the log path.

    Logging is best-effort: previously a corrupt/truncated log file (e.g. a
    partial write after a Space restart) made json.load raise and aborted the
    inference that had just succeeded. A corrupt or non-list log is now reset
    instead of crashing the caller.
    """
    ts = datetime.now().strftime("%Y%m%d")
    path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json")
    existing = []
    if os.path.exists(path):
        try:
            with open(path, "r", encoding="utf-8") as f:
                existing = json.load(f)
        except (json.JSONDecodeError, OSError):
            # Unreadable log: start a fresh one rather than fail the generation.
            existing = []
        if not isinstance(existing, list):
            existing = []
    existing.append(record)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(existing, f, ensure_ascii=False, indent=2)
    return path
def generate_from_ui(
    model_key,
    input_text,
    temperature,
    top_k,
    repetition_penalty,
    diversity_penalty,
    num_steps,
    clean_output,
):
    """Run one diffusion inference from the UI widget values.

    Returns (output_text, status_markdown, record_json) for the playground
    components, and appends the full generation record to the on-disk log.
    Raises gr.Error when no model is loaded or the input is blank.
    """
    model_bundle = _get_bundle(model_key)
    if not input_text.strip():
        raise gr.Error("Enter input text first.")
    # Deep-copy so slider overrides never leak into the cached bundle's cfg.
    cfg = copy.deepcopy(model_bundle["cfg"])
    cfg["inference"]["temperature"] = float(temperature)
    cfg["inference"]["top_k"] = int(top_k)
    cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
    cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
    cfg["inference"]["num_steps"] = int(num_steps)
    src_tok = model_bundle["src_tok"]
    tgt_tok = model_bundle["tgt_tok"]
    device = torch.device(model_bundle["device"])
    # Batch of one: encode the stripped input on the model's device.
    input_ids = torch.tensor([src_tok.encode(input_text.strip())], dtype=torch.long, device=device)
    out = run_inference(model_bundle["model"], input_ids, cfg)
    # Align decode with validation style: strip only special ids.
    # NOTE(review): pad id is hard-coded to 1 — presumably matches the target
    # tokenizer's pad token; confirm against _build_tokenizers.
    pad_id = 1
    mask_id = cfg["diffusion"]["mask_token_id"]
    decoded_ids = [x for x in out[0].tolist() if x not in (pad_id, mask_id)]
    raw_output_text = tgt_tok.decode(decoded_ids).strip()
    output_text = clean_generated_text(raw_output_text) if clean_output else raw_output_text
    if not output_text:
        output_text = "(empty output)"
    # Persist both raw and cleaned text plus every sampling knob for reproducibility.
    record = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "experiment": model_bundle["experiment"],
        "checkpoint": model_bundle["ckpt_path"],
        "input_text": input_text,
        "raw_output_text": raw_output_text,
        "output_text": output_text,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "diversity_penalty": float(diversity_penalty),
        "num_steps": int(num_steps),
        "clean_output": bool(clean_output),
    }
    log_path = save_generation(model_bundle["experiment"], record)
    status = f"Inference done. Saved: `{log_path}`"
    return output_text, status, json.dumps(record, ensure_ascii=False, indent=2)
def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze"):
    """Run analysis/run_analysis.py as a subprocess; return (exit_code, combined_log)."""
    os.makedirs(output_dir, exist_ok=True)
    task_id = str(task)
    cmd = [
        sys.executable,
        "analysis/run_analysis.py",
        "--task",
        task_id,
        "--checkpoint",
        ckpt_path,
        "--output_dir",
        output_dir,
    ]
    # Task 2 (and the "all" meta-task) needs the probe sentence; task 4 takes a phase.
    if task_id in ("2", "all"):
        cmd.extend(["--input", input_text])
    if task_id == "4":
        cmd.extend(["--phase", phase])
    # Point HF caches at /tmp so the subprocess works on read-only Space storage.
    env = os.environ.copy()
    env.setdefault("HF_HOME", "/tmp/hf_home")
    env.setdefault("HF_DATASETS_CACHE", "/tmp/hf_datasets")
    env.setdefault("HF_HUB_CACHE", "/tmp/hf_hub")
    proc = subprocess.run(cmd, capture_output=True, text=True, env=env)
    log = f"$ {' '.join(cmd)}\n\n{proc.stdout}\n{proc.stderr}"
    return proc.returncode, log
def run_single_task(model_key, task, output_dir, input_text, task4_phase):
    """Run one analysis task against the loaded model; return (status, log)."""
    bundle = _get_bundle(model_key)
    exit_code, log = _run_analysis_cmd(task, bundle["ckpt_path"], output_dir, input_text, task4_phase)
    outcome = "completed" if exit_code == 0 else "failed"
    return f"Task {task} {outcome} (exit={exit_code}).", log
def run_all_tasks(model_key, output_dir, input_text, task4_phase):
    """Run analysis tasks 1-5 sequentially; return (summary_status, concatenated_log)."""
    bundle = _get_bundle(model_key)
    logs = []
    failures = 0
    for task in ("1", "2", "3", "4", "5"):
        exit_code, log = _run_analysis_cmd(task, bundle["ckpt_path"], output_dir, input_text, task4_phase)
        logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
        if exit_code != 0:
            failures += 1
    if failures:
        status = f"Run-all finished with {failures} failed task(s)."
    else:
        status = "All 5 tasks completed."
    return status, "".join(logs)
def _read_text(path):
if not os.path.exists(path):
return "Not found."
with open(path, "r", encoding="utf-8", errors="ignore") as f:
return f.read()
def _img_or_none(path):
return path if os.path.exists(path) else None
def refresh_task_outputs(output_dir):
    """Gather report texts and plot paths from `output_dir` for the outputs viewer.

    Returns values in the exact order of the viewer components:
    (task1_txt, task2_txt, task2_drift, task2_attn, task3_txt, task3_space,
    task5_txt, task4_plot).
    """
    def report(name):
        # Text reports fall back to "Not found." when a task has not run yet.
        return _read_text(os.path.join(output_dir, name))

    def image(name):
        # Missing plots map to None so gr.Image renders empty.
        return _img_or_none(os.path.join(output_dir, name))

    return (
        report("task1_kv_cache.txt"),
        report("task2_report.txt"),
        image("task2_semantic_drift.png"),
        image("task2_attn_t0.png"),
        report("task3_report.txt"),
        image("task3_concept_space.png"),
        report("task5_report.txt"),
        image("task4_ablation_3d.png"),
    )
CUSTOM_CSS = """
:root {
--bg1: #f5fbff;
--bg2: #f2f7ef;
--card: #ffffff;
--line: #d9e6f2;
--ink: #163048;
}
.gradio-container {
background: linear-gradient(130deg, var(--bg1), var(--bg2));
color: var(--ink);
}
#hero {
background: radial-gradient(110% 130% at 0% 0%, #d7ebff 0%, #ecf6ff 55%, #f8fbff 100%);
border: 1px solid #cfe0f1;
border-radius: 16px;
padding: 18px 20px;
}
.panel {
background: var(--card);
border: 1px solid var(--line);
border-radius: 14px;
}
"""
# Top-level UI definition: a model-picker header row, two tabs (analysis task
# runner and inference playground), then all event wiring. Every callback is a
# module-level function defined above.
with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
    # Label of the currently loaded checkpoint; "" means nothing loaded yet.
    model_state = gr.State("")
    gr.Markdown(
        """
<div id="hero">
<h1 style="margin:0;">Sanskrit Diffusion Client Demo</h1>
<p style="margin:.5rem 0 0 0;">
Select any trained model, run all 5 analysis tasks or individual tasks, then test inference with user-controlled parameters.
</p>
</div>
"""
    )
    with gr.Row():
        with gr.Column(scale=2, elem_classes=["panel"]):
            checkpoint_dropdown = gr.Dropdown(
                label="Model Checkpoint",
                choices=list(checkpoint_map().keys()),
                value=default_checkpoint_label(),
                interactive=True,
            )
        with gr.Column(scale=1, elem_classes=["panel"]):
            refresh_btn = gr.Button("Refresh Models")
            load_btn = gr.Button("Load Selected Model", variant="primary")
    load_status = gr.Markdown("Select a model and load.")
    model_info = gr.Textbox(label="Loaded Model Details (JSON)", lines=12, interactive=False)
    with gr.Tabs():
        # Tab 1: run the external analysis script and browse its outputs.
        with gr.Tab("1) Task Runner"):
            with gr.Row():
                with gr.Column(scale=2):
                    analysis_output_dir = gr.Textbox(
                        label="Analysis Output Directory",
                        value=DEFAULT_ANALYSIS_OUT,
                    )
                    analysis_input = gr.Textbox(
                        label="Task 2 Input Text",
                        value="dharmo rakṣati rakṣitaḥ",
                        lines=2,
                    )
                with gr.Column(scale=1):
                    task4_phase = gr.Dropdown(
                        choices=["analyze", "generate_configs"],
                        value="analyze",
                        label="Task 4 Phase",
                    )
                    run_all_btn = gr.Button("Run All 5 Tasks", variant="primary")
            with gr.Row():
                task_choice = gr.Dropdown(
                    choices=["1", "2", "3", "4", "5"],
                    value="1",
                    label="Single Task",
                )
                run_single_btn = gr.Button("Run Selected Task")
                refresh_outputs_btn = gr.Button("Refresh Output Viewer")
            task_run_status = gr.Markdown("")
            task_run_log = gr.Textbox(label="Task Execution Log", lines=18, interactive=False)
            with gr.Accordion("Task Outputs Viewer", open=True):
                task1_box = gr.Textbox(label="Task 1 Report", lines=10, interactive=False)
                task2_box = gr.Textbox(label="Task 2 Report", lines=10, interactive=False)
                with gr.Row():
                    task2_drift_img = gr.Image(label="Task2 Drift", type="filepath")
                    task2_attn_img = gr.Image(label="Task2 Attention", type="filepath")
                task3_box = gr.Textbox(label="Task 3 Report", lines=10, interactive=False)
                task3_img = gr.Image(label="Task3 Concept Space", type="filepath")
                task5_box = gr.Textbox(label="Task 5 Report", lines=10, interactive=False)
                task4_img = gr.Image(label="Task4 3D Ablation Plot", type="filepath")
        # Tab 2: interactive generation with sampling controls.
        with gr.Tab("2) Inference Playground"):
            with gr.Row():
                with gr.Column(scale=2):
                    input_text = gr.Textbox(
                        label="Input (Roman / IAST)",
                        lines=4,
                        value="dharmo rakṣati rakṣitaḥ",
                    )
                    output_text = gr.Textbox(
                        label="Output (Devanagari)",
                        lines=7,
                        interactive=False,
                    )
                    run_status = gr.Markdown("")
                    run_record = gr.Textbox(label="Inference Metadata (JSON)", lines=12, interactive=False)
                with gr.Column(scale=1, elem_classes=["panel"]):
                    preset = gr.Radio(["Manual", "Literal", "Balanced", "Creative"], value="Balanced", label="Preset")
                    temperature = gr.Slider(0.4, 1.2, value=0.70, step=0.05, label="Temperature")
                    top_k = gr.Slider(5, 100, value=40, step=1, label="Top-K")
                    repetition_penalty = gr.Slider(1.0, 3.0, value=1.20, step=0.05, label="Repetition Penalty")
                    diversity_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Diversity Penalty")
                    num_steps = gr.Slider(1, 128, value=64, step=1, label="Inference Steps")
                    clean_output = gr.Checkbox(value=True, label="Clean Output")
                    generate_btn = gr.Button("Generate", variant="primary")
            gr.Examples(
                examples=[
                    ["dharmo rakṣati rakṣitaḥ"],
                    ["satyameva jayate"],
                    ["yadā mano nivarteta viṣayebhyaḥ svabhāvataḥ"],
                ],
                inputs=[input_text],
            )

    def refresh_checkpoints():
        # Re-scan local result dirs and the optional HF repo, then reset the dropdown.
        choices = list(checkpoint_map().keys())
        value = default_checkpoint_label() if choices else None
        return gr.update(choices=choices, value=value)

    # --- Event wiring -------------------------------------------------------
    refresh_btn.click(fn=refresh_checkpoints, outputs=[checkpoint_dropdown])
    # Loading also seeds the steps slider and suggested analysis output dir.
    load_btn.click(
        fn=load_selected_model,
        inputs=[checkpoint_dropdown],
        outputs=[model_state, load_status, model_info, num_steps, analysis_output_dir],
    )
    preset.change(
        fn=apply_preset,
        inputs=[preset],
        outputs=[temperature, top_k, repetition_penalty, diversity_penalty],
    )
    generate_btn.click(
        fn=generate_from_ui,
        inputs=[
            model_state,
            input_text,
            temperature,
            top_k,
            repetition_penalty,
            diversity_penalty,
            num_steps,
            clean_output,
        ],
        outputs=[output_text, run_status, run_record],
    )
    # Pressing Enter in the input box triggers the same generation path.
    input_text.submit(
        fn=generate_from_ui,
        inputs=[
            model_state,
            input_text,
            temperature,
            top_k,
            repetition_penalty,
            diversity_penalty,
            num_steps,
            clean_output,
        ],
        outputs=[output_text, run_status, run_record],
    )
    run_single_btn.click(
        fn=run_single_task,
        inputs=[model_state, task_choice, analysis_output_dir, analysis_input, task4_phase],
        outputs=[task_run_status, task_run_log],
    )
    run_all_btn.click(
        fn=run_all_tasks,
        inputs=[model_state, analysis_output_dir, analysis_input, task4_phase],
        outputs=[task_run_status, task_run_log],
    )
    refresh_outputs_btn.click(
        fn=refresh_task_outputs,
        inputs=[analysis_output_dir],
        outputs=[
            task1_box,
            task2_box,
            task2_drift_img,
            task2_attn_img,
            task3_box,
            task3_img,
            task5_box,
            task4_img,
        ],
    )
    # Populate the outputs viewer once on page load as well.
    demo.load(
        fn=refresh_task_outputs,
        inputs=[analysis_output_dir],
        outputs=[
            task1_box,
            task2_box,
            task2_drift_img,
            task2_attn_img,
            task3_box,
            task3_img,
            task5_box,
            task4_img,
        ],
    )
if __name__ == "__main__":
    # HF Spaces injects GRADIO_SERVER_PORT; otherwise let Gradio pick its default.
    try:
        port = int(os.environ["GRADIO_SERVER_PORT"])
    except KeyError:
        port = None
    demo.launch(server_name="0.0.0.0", server_port=port, share=False, show_api=False)