"""
Hugging Face Space app for Sanskrit D3PM project.
Deploy on Spaces with:
app_file = app_hf_space.py
Optional environment variables:
HF_CHECKPOINT_REPO : model repo id (e.g. "username/sanskrit-d3pm")
HF_CHECKPOINT_FILE : checkpoint path in repo (default: "best_model.pt")
HF_CHECKPOINT_LABEL : UI label for remote checkpoint
"""
from __future__ import annotations

import copy
import os
from typing import Dict, Tuple

import gradio as gr
import torch

from config import CONFIG
from inference import _build_tokenizers, _resolve_device, load_model, run_inference

def _clean_output(text: str, max_repeat: int = 2) -> str:
    """Collapse whitespace, cap token repeats, and reattach daṇḍa punctuation."""
    text = " ".join(text.split())
    if not text:
        return text
    toks = text.split()
    out = []
    prev = None
    run = 0
    for t in toks:
        if t == prev:
            run += 1
        else:
            prev = t
            run = 1
        # Keep at most `max_repeat` consecutive copies of any token.
        if run <= max_repeat:
            out.append(t)
    s = " ".join(out)
    # Daṇḍa (।) and double daṇḍa (॥) attach directly to the preceding word.
    s = s.replace(" ।", "।").replace(" ॥", "॥")
    return " ".join(s.split())
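
# Illustrative behaviour with the default max_repeat=2 (hypothetical input):
#   _clean_output("rāma rāma rāma gacchati ।")  ->  "rāma rāma gacchati।"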

def _discover_local_checkpoints() -> Dict[str, str]:
    """Map "<experiment> [<root>]" labels to local best_model.pt paths."""
    found = {}
    for root in ("ablation_results", "results7", "results"):
        if not os.path.isdir(root):
            continue
        for exp in sorted(os.listdir(root)):
            ckpt = os.path.join(root, exp, "best_model.pt")
            if os.path.exists(ckpt):
                found[f"{exp} [{root}]"] = ckpt
    return found
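
# For a hypothetical layout such as results/d3pm_cross_attention_neg_true/
# best_model.pt, this yields:
#   {"d3pm_cross_attention_neg_true [results]":
#    "results/d3pm_cross_attention_neg_true/best_model.pt"}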

def _discover_remote_checkpoint() -> Dict[str, str]:
    """Optionally fetch one checkpoint from the Hub, driven by env variables."""
    repo = os.getenv("HF_CHECKPOINT_REPO", "").strip()
    if not repo:
        return {}
    filename = os.getenv("HF_CHECKPOINT_FILE", "best_model.pt").strip()
    label = os.getenv("HF_CHECKPOINT_LABEL", f"remote:{repo}")
    try:
        from huggingface_hub import hf_hub_download

        ckpt_path = hf_hub_download(repo_id=repo, filename=filename)
        return {label: ckpt_path}
    except Exception as e:
        print(f"[WARN] remote checkpoint download failed: {e}")
        return {}

def _infer_model_type(path: str) -> str:
    """Guess the model type from substrings in the checkpoint path."""
    p = path.lower()
    if "d3pm_encoder_decoder" in p:
        return "d3pm_encoder_decoder"
    if "baseline_cross_attention" in p:
        return "baseline_cross_attention"
    if "baseline_encoder_decoder" in p:
        return "baseline_encoder_decoder"
    return "d3pm_cross_attention"

def _infer_neg(path: str) -> bool:
    """Guess the negative-examples flag from the path, else fall back to CONFIG."""
    p = path.lower()
    if "_neg_true" in p:
        return True
    if "_neg_false" in p:
        return False
    return CONFIG["data"]["include_negative_examples"]
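
# Example (hypothetical path): "results7/baseline_cross_attention_neg_false/
# best_model.pt" resolves to model_type "baseline_cross_attention" with
# include_negative_examples = False.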

class RuntimeStore:
    """Cache of loaded model bundles, keyed by checkpoint label."""

    def __init__(self):
        self.loaded: Dict[str, Dict] = {}

    def get(self, ckpt_label: str, ckpt_path: str) -> Dict:
        if ckpt_label in self.loaded:
            return self.loaded[ckpt_label]
        cfg = copy.deepcopy(CONFIG)
        cfg["model_type"] = _infer_model_type(ckpt_path)
        cfg["data"]["include_negative_examples"] = _infer_neg(ckpt_path)
        device = _resolve_device(cfg)
        model, cfg = load_model(ckpt_path, cfg, device)
        src_tok, tgt_tok = _build_tokenizers(cfg)
        bundle = {
            "label": ckpt_label,
            "path": ckpt_path,
            "cfg": cfg,
            "device": str(device),
            "model": model,
            "src_tok": src_tok,
            "tgt_tok": tgt_tok,
        }
        self.loaded[ckpt_label] = bundle
        return bundle
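
# Sketch of the expected flow (label and path are illustrative):
#   bundle = RUNTIME.get("my_exp [results]", "results/my_exp/best_model.pt")
#   bundle["model"], bundle["src_tok"], bundle["tgt_tok"]  # ready for inference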

RUNTIME = RuntimeStore()

CHECKPOINTS = {}
CHECKPOINTS.update(_discover_local_checkpoints())
CHECKPOINTS.update(_discover_remote_checkpoint())
if not CHECKPOINTS:
    CHECKPOINTS = {"No checkpoint found": ""}

def load_checkpoint_ui(label: str) -> Tuple[Dict, str]:
    if label not in CHECKPOINTS or not CHECKPOINTS[label]:
        raise gr.Error("No valid checkpoint found. Upload/provide best_model.pt first.")
    bundle = RUNTIME.get(label, CHECKPOINTS[label])
    info = (
        f"Loaded `{label}`\n"
        f"- path: `{bundle['path']}`\n"
        f"- model_type: `{bundle['cfg']['model_type']}`\n"
        f"- device: `{bundle['device']}`\n"
        f"- max_seq_len: `{bundle['cfg']['model']['max_seq_len']}`"
    )
    return bundle, info

def generate_ui(
    bundle: Dict,
    text: str,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    diversity_penalty: float,
    num_steps: int,
    clean_output: bool,
) -> str:
    if not bundle:
        raise gr.Error("Load a checkpoint first.")
    if not text.strip():
        raise gr.Error("Enter input text.")
    # Override the bundle's inference settings with the UI slider values.
    cfg = copy.deepcopy(bundle["cfg"])
    cfg["inference"]["temperature"] = float(temperature)
    cfg["inference"]["top_k"] = int(top_k)
    cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
    cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
    cfg["inference"]["num_steps"] = int(num_steps)
    src_tok = bundle["src_tok"]
    tgt_tok = bundle["tgt_tok"]
    device = torch.device(bundle["device"])
    ids = torch.tensor([src_tok.encode(text.strip())], dtype=torch.long, device=device)
    out = run_inference(bundle["model"], ids, cfg)
    # Drop low token ids, assumed reserved for special tokens (pad/bos/eos/etc.).
    token_ids = [x for x in out[0].tolist() if x > 4]
    pred = tgt_tok.decode(token_ids).strip()
    if clean_output:
        pred = _clean_output(pred)
    return pred if pred else "(empty output)"

with gr.Blocks(title="Sanskrit D3PM Space") as demo:
    model_state = gr.State(None)
    gr.Markdown(
        """
        ## Sanskrit D3PM Paraphrase (IAST → Devanagari)
        Load a trained checkpoint and generate output from Roman/IAST Sanskrit input.
        """
    )
    checkpoint = gr.Dropdown(
        choices=list(CHECKPOINTS.keys()),
        value=list(CHECKPOINTS.keys())[0],
        label="Checkpoint",
    )
    load_btn = gr.Button("Load Model", variant="primary")
    load_info = gr.Markdown("Select a checkpoint and click **Load Model**.")
    text_in = gr.Textbox(label="Input (Roman / IAST)", lines=3, value="dharmo rakṣati rakṣitaḥ")
    text_out = gr.Textbox(label="Output (Devanagari)", lines=6)
    with gr.Row():
        temperature = gr.Slider(0.4, 1.2, value=0.70, step=0.05, label="Temperature")
        top_k = gr.Slider(5, 100, value=40, step=1, label="Top-K")
        repetition_penalty = gr.Slider(1.0, 3.0, value=1.20, step=0.05, label="Repetition Penalty")
        diversity_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Diversity Penalty")
        num_steps = gr.Slider(1, 128, value=64, step=1, label="Inference Steps")
        clean_output = gr.Checkbox(value=True, label="Clean Output")
    generate_btn = gr.Button("Generate", variant="primary")
    load_btn.click(load_checkpoint_ui, inputs=[checkpoint], outputs=[model_state, load_info])
    # Both the button and the textbox's submit event trigger generation.
    gen_inputs = [
        model_state, text_in, temperature, top_k, repetition_penalty,
        diversity_penalty, num_steps, clean_output,
    ]
    generate_btn.click(generate_ui, inputs=gen_inputs, outputs=[text_out])
    text_in.submit(generate_ui, inputs=gen_inputs, outputs=[text_out])

if __name__ == "__main__":
    port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port, share=False)