# devflow/app.py — Gradio demo for the DevaFlow Sanskrit diffusion model.
# (Hugging Face Space page chrome from the original upload removed.)
import copy
import json
import os
import subprocess
import sys
import shutil
from datetime import datetime
from pathlib import Path
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from config import CONFIG
from inference import _resolve_device, load_model, run_inference, _decode_clean, _decode_with_cleanup
from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
# Directory where the UI appends per-experiment generation logs (JSON).
RESULTS_DIR = "generated_results"
# Default directory the analysis tasks write their reports/plots into.
DEFAULT_ANALYSIS_OUT = "analysis/outputs"
os.makedirs(RESULTS_DIR, exist_ok=True)
# Hub repo/file used as a fallback checkpoint when none are bundled locally.
HF_DEFAULT_MODEL_REPO = os.environ.get("HF_DEFAULT_MODEL_REPO", "bhsinghgrid/DevaFlow")
HF_DEFAULT_MODEL_FILE = os.environ.get("HF_DEFAULT_MODEL_FILE", "best_model.pt")
def _download_hf_default_checkpoint():
    """Best-effort download of the default checkpoint from the HF Hub.

    Returns the local checkpoint path, or None when the download fails
    for any reason (offline Space, missing repo, permissions, ...).
    """
    cache = Path(".hf_model_cache")
    try:
        cache.mkdir(parents=True, exist_ok=True)
        return hf_hub_download(
            repo_id=HF_DEFAULT_MODEL_REPO,
            filename=HF_DEFAULT_MODEL_FILE,
            local_dir=str(cache),
            local_dir_use_symlinks=False,
        )
    except Exception:
        # Deliberate best-effort: callers treat None as "no HF fallback".
        return None
def discover_checkpoints():
    """Scan the known result directories for trained checkpoints.

    Returns a list of dicts with ``label``/``path``/``experiment``/``root``
    keys. The HF Hub fallback checkpoint is always appended when it can be
    downloaded, so a fresh Space exposes at least one usable model.
    """
    found = []
    for root in ("ablation_results", "results7", "results"):
        if not os.path.isdir(root):
            continue
        for entry in sorted(os.listdir(root)):
            ckpt = os.path.join(root, entry, "best_model.pt")
            if not os.path.exists(ckpt):
                continue
            found.append({
                "label": f"{entry} [{root}]",
                "path": ckpt,
                "experiment": entry,
                "root": root,
            })
    # Space-safe fallback: always expose one downloadable checkpoint option.
    hf_ckpt = _download_hf_default_checkpoint()
    if hf_ckpt and os.path.exists(hf_ckpt):
        found.append({
            "label": f"HF default [{HF_DEFAULT_MODEL_REPO}]",
            "path": hf_ckpt,
            "experiment": "hf_default",
            "root": "hf",
        })
    return found
def _guess_analysis_dir(experiment: str, ckpt_path: str) -> str:
base = Path("analysis_outputs")
if base.exists():
if experiment and (base / experiment).is_dir():
return str(base / experiment)
for part in Path(ckpt_path).parts:
if part.startswith("T") and part[1:].isdigit() and (base / part).is_dir():
return str(base / part)
if (base / "T4").is_dir():
return str(base / "T4")
return os.path.join("analysis", "outputs_ui", experiment or "default")
def checkpoint_map():
    """Map each display label to its checkpoint metadata dict."""
    mapping = {}
    for item in discover_checkpoints():
        mapping[item["label"]] = item
    return mapping
def default_checkpoint_label():
    """Pick the preferred default checkpoint label.

    Prefers the ``ablation_results/T4`` checkpoint when present, otherwise
    the first discovered checkpoint; returns None when none exist.
    """
    cps = discover_checkpoints()
    if not cps:
        return None
    # Compare path components instead of a raw endswith() on a "/"-joined
    # string, so the preferred match also works with Windows separators.
    preferred_tail = ("ablation_results", "T4", "best_model.pt")
    for item in cps:
        if tuple(Path(item["path"]).parts[-3:]) == preferred_tail:
            return item["label"]
    return cps[0]["label"]
def infer_model_type(experiment_name: str, root: str = "") -> str:
    """Derive the model architecture from an experiment's name/location."""
    if root == "ablation_results":
        # Ablation runs are always the cross-attention D3PM variant.
        return "d3pm_cross_attention"
    # Known architecture prefixes, checked in declaration order.
    known_types = (
        "d3pm_cross_attention",
        "d3pm_encoder_decoder",
        "baseline_cross_attention",
        "baseline_encoder_decoder",
    )
    for prefix in known_types:
        if experiment_name.startswith(prefix):
            return prefix
    return CONFIG["model_type"]
def infer_include_negative(experiment_name: str, root: str = "") -> bool:
    """Infer whether a run was trained with negative examples."""
    if root == "ablation_results":
        # Ablation runs were trained without negatives.
        return False
    for marker, flag in (("_neg_True", True), ("_neg_False", False)):
        if marker in experiment_name:
            return flag
    return CONFIG["data"]["include_negative_examples"]
def build_runtime_cfg(ckpt_path: str):
    """Build a per-checkpoint runtime config from the global CONFIG.

    Returns ``(cfg, device, experiment_name)``.
    """
    ckpt = Path(ckpt_path)
    experiment = ckpt.parent.name
    root = ckpt.parent.parent.name
    cfg = copy.deepcopy(CONFIG)
    cfg["model_type"] = infer_model_type(experiment, root=root)
    cfg["data"]["include_negative_examples"] = infer_include_negative(experiment, root=root)
    # Ablation dirs encode the diffusion-step count in their name (T4, T8, ...).
    if root == "ablation_results" and experiment.startswith("T") and experiment[1:].isdigit():
        steps = int(experiment[1:])
        cfg["model"]["diffusion_steps"] = steps
        cfg["inference"]["num_steps"] = steps
    device = _resolve_device(cfg.get("training", {}).get("device", "cpu"))
    return cfg, device, experiment
def _build_tokenizers(cfg):
    """Instantiate the source/target tokenizers sized from the model config."""
    max_len = cfg["model"]["max_seq_len"]
    src_tok = SanskritSourceTokenizer(
        vocab_size=cfg["model"].get("src_vocab_size", 16000),
        max_len=max_len,
    )
    tgt_tok = SanskritTargetTokenizer(
        vocab_size=cfg["model"].get("tgt_vocab_size", 16000),
        max_len=max_len,
    )
    return src_tok, tgt_tok
def load_selected_model(checkpoint_label):
    """Load the checkpoint behind *checkpoint_label* and assemble UI state.

    Returns ``(bundle, status_markdown, model_info, num_steps, suggested_out_dir)``.
    Raises gr.Error when no checkpoints exist or the label is stale.
    """
    mapping = checkpoint_map()
    if not mapping:
        raise gr.Error("No checkpoints found. Add models under ablation_results/ or results*/.")
    # Fall back to the default label when none was selected yet.
    label = checkpoint_label or default_checkpoint_label()
    if label not in mapping:
        raise gr.Error("Selected checkpoint not found. Click refresh.")
    ckpt_path = mapping[label]["path"]
    cfg, device, experiment = build_runtime_cfg(ckpt_path)
    model, cfg = load_model(ckpt_path, cfg, device)
    src_tok, tgt_tok = _build_tokenizers(cfg)
    bundle = dict(
        ckpt_path=ckpt_path,
        experiment=experiment,
        device=str(device),
        cfg=cfg,
        model=model,
        src_tok=src_tok,
        tgt_tok=tgt_tok,
    )
    model_info = dict(
        checkpoint=ckpt_path,
        experiment=experiment,
        model_type=cfg["model_type"],
        include_negatives=cfg["data"]["include_negative_examples"],
        device=str(device),
        max_seq_len=cfg["model"]["max_seq_len"],
        diffusion_steps=cfg["model"]["diffusion_steps"],
        inference_steps=cfg["inference"]["num_steps"],
        d_model=cfg["model"]["d_model"],
        n_layers=cfg["model"]["n_layers"],
        n_heads=cfg["model"]["n_heads"],
    )
    status = f"Loaded `{experiment}` on `{device}` (`{cfg['model_type']}`)"
    suggested_out = _guess_analysis_dir(experiment, ckpt_path)
    return bundle, status, model_info, cfg["inference"]["num_steps"], suggested_out
def apply_preset(preset_name):
    """Return (temperature, top_k, repetition_penalty, diversity_penalty) for a preset."""
    balanced = (0.70, 40, 1.20, 0.0)
    table = {
        "Manual": (0.70, 40, 1.20, 0.0),
        "Literal": (0.60, 20, 1.25, 0.0),
        "Balanced": balanced,
        "Creative": (0.90, 80, 1.05, 0.2),
    }
    # Unknown names fall back to the balanced settings.
    return table.get(preset_name, balanced)
def clean_generated_text(text: str, max_consecutive: int = 2) -> str:
    """Collapse excessive token repetition and tidy danda spacing.

    Any token repeated more than *max_consecutive* times in a row is
    truncated to that many occurrences; spaces before the danda marks
    (।, ॥) are removed.
    """
    tokens = text.split()
    if not tokens:
        return ""
    kept = []
    run_tok, run_len = None, 0
    for tok in tokens:
        if tok == run_tok:
            run_len += 1
        else:
            run_tok, run_len = tok, 1
        if run_len <= max_consecutive:
            kept.append(tok)
    joined = " ".join(kept).replace(" ।", "।").replace(" ॥", "॥")
    return " ".join(joined.split())
def save_generation(experiment, record):
    """Append *record* to today's per-experiment JSON log and return its path.

    A corrupt, unreadable, or non-list log file is replaced with a fresh
    list instead of raising — otherwise one bad write would permanently
    block every subsequent generation for that experiment/day.
    """
    ts = datetime.now().strftime("%Y%m%d")
    path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json")
    existing = []
    if os.path.exists(path):
        try:
            with open(path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            # Guard against a log that isn't the expected list-of-records.
            if isinstance(loaded, list):
                existing = loaded
        except (OSError, json.JSONDecodeError):
            existing = []
    existing.append(record)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(existing, f, ensure_ascii=False, indent=2)
    return path
def generate_from_ui(
    model_bundle,
    input_text,
    temperature,
    top_k,
    repetition_penalty,
    diversity_penalty,
    num_steps,
    clean_output,
):
    """Run one generation from the playground tab.

    Overrides the loaded bundle's inference settings with the UI control
    values, runs inference, decodes (optionally with cleanup), appends a
    JSON record to the per-experiment log, and returns
    ``(output_text, status_markdown, record_dict)``.

    Raises gr.Error when no model is loaded or the input is empty.
    """
    if not model_bundle:
        raise gr.Error("Load a model first.")
    if not input_text.strip():
        raise gr.Error("Enter input text first.")
    # Work on a copy so UI overrides don't leak into the cached bundle cfg.
    cfg = copy.deepcopy(model_bundle["cfg"])
    cfg["inference"]["temperature"] = float(temperature)
    cfg["inference"]["top_k"] = int(top_k)
    cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
    cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
    cfg["inference"]["num_steps"] = int(num_steps)
    src_tok = model_bundle["src_tok"]
    tgt_tok = model_bundle["tgt_tok"]
    device = torch.device(model_bundle["device"])
    # Batch of one, truncated to the model's maximum sequence length.
    input_ids = torch.tensor(
        [src_tok.encode(input_text.strip())[:cfg["model"]["max_seq_len"]]],
        dtype=torch.long,
        device=device,
    )
    out = run_inference(model_bundle["model"], input_ids, cfg)
    # Use the exact inference decode/cleanup logic for parity with inference.py
    raw_output_text = _decode_clean(tgt_tok, out[0].tolist())
    if clean_output:
        output_text = _decode_with_cleanup(
            tgt_tok, out[0].tolist(), input_text.strip(), cfg["inference"]
        )
    else:
        output_text = raw_output_text
    if not output_text:
        output_text = "(empty output)"
    # Record both raw and cleaned outputs plus all sampling parameters.
    record = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "experiment": model_bundle["experiment"],
        "checkpoint": model_bundle["ckpt_path"],
        "input_text": input_text,
        "raw_output_text": raw_output_text,
        "output_text": output_text,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "diversity_penalty": float(diversity_penalty),
        "num_steps": int(num_steps),
        "clean_output": bool(clean_output),
    }
    log_path = save_generation(model_bundle["experiment"], record)
    status = f"Inference done. Saved: `{log_path}`"
    return output_text, status, record
def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze"):
os.makedirs(output_dir, exist_ok=True)
script = Path("analysis") / "run_analysis.py"
if not script.exists():
return 2, "Analysis runner missing in Space image. Falling back to bundled analysis outputs."
# Space-safe Task4 fallback: if ablation models don't exist, bootstrap them
# from currently selected checkpoint so Task4 can still execute end-to-end.
if str(task) == "4" and phase == "analyze":
for t in (4, 8, 16, 32, 64):
t_dir = Path("ablation_results") / f"T{t}"
t_dir.mkdir(parents=True, exist_ok=True)
dst = t_dir / "best_model.pt"
if not dst.exists():
try:
os.symlink(os.path.abspath(ckpt_path), str(dst))
except Exception:
import shutil
shutil.copy2(ckpt_path, str(dst))
cmd = [
sys.executable,
str(script),
"--task",
str(task),
"--checkpoint",
ckpt_path,
"--output_dir",
output_dir,
]
if str(task) == "2" or str(task) == "all":
cmd.extend(["--input", input_text])
if str(task) == "4":
cmd.extend(["--phase", phase])
env = os.environ.copy()
env.setdefault("HF_HOME", "/tmp/hf_home")
env.setdefault("HF_DATASETS_CACHE", "/tmp/hf_datasets")
env.setdefault("HF_HUB_CACHE", "/tmp/hf_hub")
proc = subprocess.run(cmd, capture_output=True, text=True, env=env)
log = f"$ {' '.join(cmd)}\n\n{proc.stdout}\n{proc.stderr}"
return proc.returncode, log
def _bundle_task_outputs(model_bundle, output_dir):
    """Copy bundled analysis outputs into *output_dir* as a fallback."""
    source_dir = _guess_analysis_dir(model_bundle.get("experiment", ""), model_bundle.get("ckpt_path", ""))
    if not os.path.isdir(source_dir):
        # Nothing bundled for this run; silently skip.
        return
    os.makedirs(output_dir, exist_ok=True)
    for name in os.listdir(source_dir):
        candidate = os.path.join(source_dir, name)
        if os.path.isfile(candidate):
            shutil.copy2(candidate, os.path.join(output_dir, name))
def _live_input_summary(model_bundle, input_text: str) -> str:
    """Run one inference on *input_text* and summarise basic output stats."""
    stripped = input_text.strip()
    if not stripped:
        return "No input text provided."
    cfg = copy.deepcopy(model_bundle["cfg"])
    device = torch.device(model_bundle["device"])
    ids = model_bundle["src_tok"].encode(stripped)[: cfg["model"]["max_seq_len"]]
    batch = torch.tensor([ids], dtype=torch.long, device=device)
    out = run_inference(model_bundle["model"], batch, cfg)
    pred = _decode_with_cleanup(model_bundle["tgt_tok"], out[0].tolist(), stripped, cfg["inference"])
    toks = pred.split()
    # Unique-token ratio is a cheap proxy for degenerate/repetitive output.
    ratio = len(set(toks)) / max(1, len(toks))
    summary_lines = [
        f"Live input: {input_text}",
        f"Prediction: {pred}",
        f"Length(tokens): {len(toks)}",
        f"Unique-token ratio: {ratio:.3f}",
    ]
    return "\n".join(summary_lines)
def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
    """Execute one analysis task, falling back to bundled outputs on failure."""
    if not model_bundle:
        raise gr.Error("Load a model first.")
    code, log = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
    if code == 0:
        return f"Task {task} completed (exit={code}).", log
    # Subprocess failed: copy bundled reports and append a live inference summary.
    _bundle_task_outputs(model_bundle, output_dir)
    log = f"{log}\n\n--- Live input summary ---\n{_live_input_summary(model_bundle, input_text)}"
    return f"Task {task} fallback mode: bundled reports + live input analysis.", log
def run_all_tasks(model_bundle, output_dir, input_text, task4_phase):
    """Run analysis tasks 1-5 in sequence; collate logs, fall back if any fail."""
    if not model_bundle:
        raise gr.Error("Load a model first.")
    logs = []
    failures = 0
    for task in ("1", "2", "3", "4", "5"):
        code, log = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
        logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
        failures += int(code != 0)
    if failures:
        # At least one task failed: provide bundled reports + a live summary.
        _bundle_task_outputs(model_bundle, output_dir)
        logs.append(f"\n\n--- Live input summary ---\n{_live_input_summary(model_bundle, input_text)}")
        status = f"Run-all finished with {failures} fallback task(s)."
    else:
        status = "All 5 tasks completed."
    return status, "".join(logs)
def _read_text(path):
if not os.path.exists(path):
return "Not found."
with open(path, "r", encoding="utf-8", errors="ignore") as f:
return f.read()
def _img_or_none(path):
return path if os.path.exists(path) else None
def refresh_task_outputs(output_dir):
    """Load all task reports and plots from *output_dir* for the viewer.

    Returns the 8 values expected by the output-viewer components:
    (task1_txt, task2_txt, task2_drift, task2_attn, task3_txt,
    task3_space, task5_txt, task4_plot).
    """
    def txt(name):
        return _read_text(os.path.join(output_dir, name))

    def img(name):
        return _img_or_none(os.path.join(output_dir, name))

    # The 3D plot was renamed at some point; accept either filename.
    task4_plot = img("task4_ablation_3d.png") or img("task4_3d.png")
    return (
        txt("task1_kv_cache.txt"),
        txt("task2_report.txt"),
        img("task2_semantic_drift.png"),
        img("task2_attn_t0.png"),
        txt("task3_report.txt"),
        img("task3_concept_space.png"),
        txt("task5_report.txt"),
        task4_plot,
    )
# Light-theme styling for the app: gradient page background, a "hero" banner,
# and white card "panel" containers. Injected via gr.Blocks(css=CUSTOM_CSS).
CUSTOM_CSS = """
:root {
--bg1: #f5fbff;
--bg2: #f2f7ef;
--card: #ffffff;
--line: #d9e6f2;
--ink: #163048;
}
.gradio-container {
background: linear-gradient(130deg, var(--bg1), var(--bg2));
color: var(--ink);
}
#hero {
background: radial-gradient(110% 130% at 0% 0%, #d7ebff 0%, #ecf6ff 55%, #f8fbff 100%);
border: 1px solid #cfe0f1;
border-radius: 16px;
padding: 18px 20px;
}
.panel {
background: var(--card);
border: 1px solid var(--line);
border-radius: 14px;
}
"""
# ---------------------------------------------------------------------------
# UI definition: model selector, Task Runner tab, Inference Playground tab,
# and the event wiring that connects them to the functions above.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
    # Shared state: the bundle dict produced by load_selected_model (or None).
    model_state = gr.State(None)
    gr.Markdown(
        """
<div id="hero">
<h1 style="margin:0;">Sanskrit Diffusion Client Demo</h1>
<p style="margin:.5rem 0 0 0;">
Select any trained model, run all 5 analysis tasks or individual tasks, then test inference with user-controlled parameters.
</p>
</div>
"""
    )
    # --- Model selection row ---
    with gr.Row():
        with gr.Column(scale=2, elem_classes=["panel"]):
            checkpoint_dropdown = gr.Dropdown(
                label="Model Checkpoint",
                choices=list(checkpoint_map().keys()),
                value=default_checkpoint_label(),
                interactive=True,
            )
        with gr.Column(scale=1, elem_classes=["panel"]):
            refresh_btn = gr.Button("Refresh Models")
            load_btn = gr.Button("Load Selected Model", variant="primary")
    init_msg = "Select a model and load." if checkpoint_map() else "No checkpoints found in ablation_results/ or results*/."
    load_status = gr.Markdown(init_msg)
    model_info = gr.JSON(label="Loaded Model Details")
    with gr.Tabs():
        # --- Tab 1: run the 5 analysis tasks and browse their outputs ---
        with gr.Tab("1) Task Runner"):
            with gr.Row():
                with gr.Column(scale=2):
                    analysis_output_dir = gr.Textbox(
                        label="Analysis Output Directory",
                        value=DEFAULT_ANALYSIS_OUT,
                    )
                    analysis_input = gr.Textbox(
                        label="Task 2 Input Text",
                        value="dharmo rakṣati rakṣitaḥ",
                        lines=2,
                    )
                with gr.Column(scale=1):
                    task4_phase = gr.Dropdown(
                        choices=["analyze", "generate_configs"],
                        value="analyze",
                        label="Task 4 Phase",
                    )
                    run_all_btn = gr.Button("Run All 5 Tasks", variant="primary")
            with gr.Row():
                task_choice = gr.Dropdown(
                    choices=["1", "2", "3", "4", "5"],
                    value="1",
                    label="Single Task",
                )
                run_single_btn = gr.Button("Run Selected Task")
                refresh_outputs_btn = gr.Button("Refresh Output Viewer")
            task_run_status = gr.Markdown("")
            task_run_log = gr.Textbox(label="Task Execution Log", lines=18, interactive=False)
            with gr.Accordion("Task Outputs Viewer", open=True):
                task1_box = gr.Textbox(label="Task 1 Report", lines=10, interactive=False)
                task2_box = gr.Textbox(label="Task 2 Report", lines=10, interactive=False)
                with gr.Row():
                    task2_drift_img = gr.Image(label="Task2 Drift", type="filepath")
                    task2_attn_img = gr.Image(label="Task2 Attention", type="filepath")
                task3_box = gr.Textbox(label="Task 3 Report", lines=10, interactive=False)
                task3_img = gr.Image(label="Task3 Concept Space", type="filepath")
                task5_box = gr.Textbox(label="Task 5 Report", lines=10, interactive=False)
                task4_img = gr.Image(label="Task4 3D Ablation Plot", type="filepath")
        # --- Tab 2: free-form inference with user-controlled sampling ---
        with gr.Tab("2) Inference Playground"):
            with gr.Row():
                with gr.Column(scale=2):
                    input_text = gr.Textbox(
                        label="Input (Roman / IAST)",
                        lines=4,
                        value="dharmo rakṣati rakṣitaḥ",
                    )
                    output_text = gr.Textbox(
                        label="Output (Devanagari)",
                        lines=7,
                        interactive=False,
                    )
                    run_status = gr.Markdown("")
                    run_record = gr.JSON(label="Inference Metadata")
                with gr.Column(scale=1, elem_classes=["panel"]):
                    preset = gr.Radio(["Manual", "Literal", "Balanced", "Creative"], value="Balanced", label="Preset")
                    temperature = gr.Slider(0.4, 1.2, value=0.70, step=0.05, label="Temperature")
                    top_k = gr.Slider(5, 100, value=40, step=1, label="Top-K")
                    repetition_penalty = gr.Slider(1.0, 3.0, value=1.20, step=0.05, label="Repetition Penalty")
                    diversity_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Diversity Penalty")
                    num_steps = gr.Slider(1, 128, value=64, step=1, label="Inference Steps")
                    clean_output = gr.Checkbox(value=True, label="Clean Output")
                    generate_btn = gr.Button("Generate", variant="primary")
            gr.Examples(
                examples=[
                    ["dharmo rakṣati rakṣitaḥ"],
                    ["satyameva jayate"],
                    ["yadā mano nivarteta viṣayebhyaḥ svabhāvataḥ"],
                ],
                inputs=[input_text],
            )

    def refresh_checkpoints():
        # Re-scan disk/HF for checkpoints and repopulate the dropdown.
        choices = list(checkpoint_map().keys())
        value = default_checkpoint_label() if choices else None
        msg = f"Found {len(choices)} checkpoint(s)." if choices else "No checkpoints found."
        return gr.Dropdown(choices=choices, value=value), msg

    def auto_load_default():
        # Load the default checkpoint on page load; degrade gracefully when none exist.
        choices = list(checkpoint_map().keys())
        if not choices:
            return None, "No checkpoints found.", {}, 64, DEFAULT_ANALYSIS_OUT
        return load_selected_model(default_checkpoint_label())

    # --- Event wiring ---
    refresh_btn.click(fn=refresh_checkpoints, outputs=[checkpoint_dropdown, load_status])
    load_btn.click(
        fn=load_selected_model,
        inputs=[checkpoint_dropdown],
        outputs=[model_state, load_status, model_info, num_steps, analysis_output_dir],
    )
    preset.change(
        fn=apply_preset,
        inputs=[preset],
        outputs=[temperature, top_k, repetition_penalty, diversity_penalty],
    )
    generate_btn.click(
        fn=generate_from_ui,
        inputs=[
            model_state,
            input_text,
            temperature,
            top_k,
            repetition_penalty,
            diversity_penalty,
            num_steps,
            clean_output,
        ],
        outputs=[output_text, run_status, run_record],
    )
    # Pressing Enter in the input box triggers the same generation handler.
    input_text.submit(
        fn=generate_from_ui,
        inputs=[
            model_state,
            input_text,
            temperature,
            top_k,
            repetition_penalty,
            diversity_penalty,
            num_steps,
            clean_output,
        ],
        outputs=[output_text, run_status, run_record],
    )
    run_single_btn.click(
        fn=run_single_task,
        inputs=[model_state, task_choice, analysis_output_dir, analysis_input, task4_phase],
        outputs=[task_run_status, task_run_log],
    )
    run_all_btn.click(
        fn=run_all_tasks,
        inputs=[model_state, analysis_output_dir, analysis_input, task4_phase],
        outputs=[task_run_status, task_run_log],
    )
    refresh_outputs_btn.click(
        fn=refresh_task_outputs,
        inputs=[analysis_output_dir],
        outputs=[
            task1_box,
            task2_box,
            task2_drift_img,
            task2_attn_img,
            task3_box,
            task3_img,
            task5_box,
            task4_img,
        ],
    )
    # On page load: auto-load the default model, then populate the viewer.
    demo.load(
        fn=auto_load_default,
        outputs=[model_state, load_status, model_info, num_steps, analysis_output_dir],
    )
    demo.load(
        fn=refresh_task_outputs,
        inputs=[analysis_output_dir],
        outputs=[
            task1_box,
            task2_box,
            task2_drift_img,
            task2_attn_img,
            task3_box,
            task3_img,
            task5_box,
            task4_img,
        ],
    )
if __name__ == "__main__":
    # Hugging Face Spaces sets GRADIO_SERVER_PORT; locally let Gradio pick.
    env_port = os.environ.get("GRADIO_SERVER_PORT")
    port = int(env_port) if env_port is not None else None
    demo.launch(server_name="0.0.0.0", server_port=port, share=False)