# NOTE: Hugging Face Spaces page banner ("Spaces / Sleeping") removed — it was
# extraction residue, not part of the program.
| """ | |
| Hugging Face Space app for Sanskrit D3PM project. | |
| Deploy on Spaces with: | |
| app_file = app_hf_space.py | |
| Optional environment variables: | |
| HF_CHECKPOINT_REPO : model repo id (e.g. "username/sanskrit-d3pm") | |
| HF_CHECKPOINT_FILE : checkpoint path in repo (default: "best_model.pt") | |
| HF_CHECKPOINT_LABEL : UI label for remote checkpoint | |
| """ | |
| from __future__ import annotations | |
| import copy | |
| import os | |
| from typing import Dict, Tuple | |
| import gradio as gr | |
| import torch | |
| from config import CONFIG | |
| from inference import _build_tokenizers, _resolve_device, load_model, run_inference | |
| def _clean_output(text: str, max_repeat: int = 2) -> str: | |
| text = " ".join(text.split()) | |
| if not text: | |
| return text | |
| toks = text.split() | |
| out = [] | |
| prev = None | |
| run = 0 | |
| for t in toks: | |
| if t == prev: | |
| run += 1 | |
| else: | |
| prev = t | |
| run = 1 | |
| if run <= max_repeat: | |
| out.append(t) | |
| s = " ".join(out) | |
| s = s.replace(" ।", "।").replace(" ॥", "॥") | |
| return " ".join(s.split()) | |
| def _discover_local_checkpoints() -> Dict[str, str]: | |
| found = {} | |
| for root in ("ablation_results", "results7", "results"): | |
| if not os.path.isdir(root): | |
| continue | |
| for exp in sorted(os.listdir(root)): | |
| ckpt = os.path.join(root, exp, "best_model.pt") | |
| if os.path.exists(ckpt): | |
| found[f"{exp} [{root}]"] = ckpt | |
| return found | |
| def _discover_remote_checkpoint() -> Dict[str, str]: | |
| repo = os.getenv("HF_CHECKPOINT_REPO", "").strip() | |
| if not repo: | |
| return {} | |
| filename = os.getenv("HF_CHECKPOINT_FILE", "best_model.pt").strip() | |
| label = os.getenv("HF_CHECKPOINT_LABEL", f"remote:{repo}") | |
| try: | |
| from huggingface_hub import hf_hub_download | |
| ckpt_path = hf_hub_download(repo_id=repo, filename=filename) | |
| return {label: ckpt_path} | |
| except Exception as e: | |
| print(f"[WARN] remote checkpoint download failed: {e}") | |
| return {} | |
| def _infer_model_type(path: str) -> str: | |
| p = path.lower() | |
| if "d3pm_encoder_decoder" in p: | |
| return "d3pm_encoder_decoder" | |
| if "baseline_cross_attention" in p: | |
| return "baseline_cross_attention" | |
| if "baseline_encoder_decoder" in p: | |
| return "baseline_encoder_decoder" | |
| return "d3pm_cross_attention" | |
| def _infer_neg(path: str) -> bool: | |
| p = path.lower() | |
| if "_neg_true" in p: | |
| return True | |
| if "_neg_false" in p: | |
| return False | |
| return CONFIG["data"]["include_negative_examples"] | |
class RuntimeStore:
    """Lazy, per-process cache of fully initialised model bundles.

    Bundles are keyed by the UI label so a checkpoint is loaded (and its
    tokenizers built) at most once per app lifetime.
    """

    def __init__(self):
        # label -> bundle dict (see get() for the bundle schema)
        self.loaded: Dict[str, Dict] = {}

    def get(self, ckpt_label: str, ckpt_path: str) -> Dict:
        """Return the cached bundle for ``ckpt_label``, loading it on first use.

        The bundle carries the model, its (checkpoint-adjusted) config,
        resolved device string, and both tokenizers.
        """
        cached = self.loaded.get(ckpt_label)
        if cached is not None:
            return cached
        # Work on a deep copy so checkpoint-specific overrides never
        # mutate the shared project CONFIG.
        cfg = copy.deepcopy(CONFIG)
        cfg["model_type"] = _infer_model_type(ckpt_path)
        cfg["data"]["include_negative_examples"] = _infer_neg(ckpt_path)
        device = _resolve_device(cfg)
        model, cfg = load_model(ckpt_path, cfg, device)
        src_tok, tgt_tok = _build_tokenizers(cfg)
        bundle = {
            "label": ckpt_label,
            "path": ckpt_path,
            "cfg": cfg,
            "device": str(device),
            "model": model,
            "src_tok": src_tok,
            "tgt_tok": tgt_tok,
        }
        self.loaded[ckpt_label] = bundle
        return bundle
# Module-level runtime cache plus the checkpoint catalogue for the dropdown.
RUNTIME = RuntimeStore()

# Remote entries are merged last so a configured Hub checkpoint can
# shadow a local one with the same label.
CHECKPOINTS: Dict[str, str] = {
    **_discover_local_checkpoints(),
    **_discover_remote_checkpoint(),
}
if not CHECKPOINTS:
    # Keep the dropdown non-empty; the empty path is rejected at load time.
    CHECKPOINTS = {"No checkpoint found": ""}
def load_checkpoint_ui(label: str) -> Tuple[Dict, str]:
    """Gradio callback: load the selected checkpoint and return (bundle, summary).

    Raises gr.Error (shown as a toast in the UI) when the label is unknown
    or maps to an empty path (the "no checkpoint" placeholder).
    """
    path = CHECKPOINTS.get(label, "")
    if not path:
        raise gr.Error("No valid checkpoint found. Upload/provide best_model.pt first.")
    bundle = RUNTIME.get(label, path)
    cfg = bundle["cfg"]
    summary = "\n".join(
        [
            f"Loaded `{label}`",
            f"- path: `{bundle['path']}`",
            f"- model_type: `{cfg['model_type']}`",
            f"- device: `{bundle['device']}`",
            f"- max_seq_len: `{cfg['model']['max_seq_len']}`",
        ]
    )
    return bundle, summary
def generate_ui(
    bundle: Dict,
    text: str,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    diversity_penalty: float,
    num_steps: int,
    clean_output: bool,
) -> str:
    """Gradio callback: run diffusion inference on one input string.

    ``bundle`` is the state produced by load_checkpoint_ui; the remaining
    arguments mirror the UI sliders/checkbox. Returns the decoded prediction
    or "(empty output)" when decoding yields nothing.
    """
    if not bundle:
        raise gr.Error("Load a checkpoint first.")
    source_text = text.strip()
    if not source_text:
        raise gr.Error("Enter input text.")
    # Deep-copy so slider overrides never leak back into the cached bundle.
    cfg = copy.deepcopy(bundle["cfg"])
    cfg["inference"].update(
        temperature=float(temperature),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
        diversity_penalty=float(diversity_penalty),
        num_steps=int(num_steps),
    )
    device = torch.device(bundle["device"])
    encoded = bundle["src_tok"].encode(source_text)
    ids = torch.tensor([encoded], dtype=torch.long, device=device)
    out = run_inference(bundle["model"], ids, cfg)
    # Drop low ids before decoding; ids 0-4 appear to be reserved special
    # tokens (pad/bos/eos/...) — TODO confirm against the tokenizer.
    kept_ids = [tok for tok in out[0].tolist() if tok > 4]
    pred = bundle["tgt_tok"].decode(kept_ids).strip()
    if clean_output:
        pred = _clean_output(pred)
    return pred or "(empty output)"
# ---- UI definition ---------------------------------------------------------
with gr.Blocks(title="Sanskrit D3PM Space") as demo:
    # Holds the loaded model bundle between callbacks.
    model_state = gr.State(None)
    gr.Markdown(
        """
        ## Sanskrit D3PM Paraphrase (IAST → Devanagari)
        Load a trained checkpoint and generate output from Roman/IAST Sanskrit input.
        """
    )
    checkpoint = gr.Dropdown(
        choices=list(CHECKPOINTS),
        value=next(iter(CHECKPOINTS)),
        label="Checkpoint",
    )
    load_btn = gr.Button("Load Model", variant="primary")
    load_info = gr.Markdown("Select a checkpoint and click **Load Model**.")
    text_in = gr.Textbox(label="Input (Roman / IAST)", lines=3, value="dharmo rakṣati rakṣitaḥ")
    text_out = gr.Textbox(label="Output (Devanagari)", lines=6)
    with gr.Row():
        temperature = gr.Slider(0.4, 1.2, value=0.70, step=0.05, label="Temperature")
        top_k = gr.Slider(5, 100, value=40, step=1, label="Top-K")
        repetition_penalty = gr.Slider(1.0, 3.0, value=1.20, step=0.05, label="Repetition Penalty")
        diversity_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Diversity Penalty")
        num_steps = gr.Slider(1, 128, value=64, step=1, label="Inference Steps")
        clean_output = gr.Checkbox(value=True, label="Clean Output")
    generate_btn = gr.Button("Generate", variant="primary")

    # Wiring: both the Generate button and Enter-in-textbox run inference
    # with the same inputs.
    generation_inputs = [
        model_state,
        text_in,
        temperature,
        top_k,
        repetition_penalty,
        diversity_penalty,
        num_steps,
        clean_output,
    ]
    load_btn.click(load_checkpoint_ui, inputs=[checkpoint], outputs=[model_state, load_info])
    generate_btn.click(generate_ui, inputs=generation_inputs, outputs=[text_out])
    text_in.submit(generate_ui, inputs=generation_inputs, outputs=[text_out])
if __name__ == "__main__":
    # Spaces injects GRADIO_SERVER_PORT; 7860 is Gradio's standard local port.
    server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=server_port, share=False)