# app.py — Hugging Face Spaces (Gradio) demo for EN→ZH translation.
from pathlib import Path

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# The fine-tuned EN->ZH T5 model is expected to live next to this script.
ROOT = Path(__file__).parent
MODEL_ROOT = ROOT / "custom_t5_enzh"

# Task prefix the T5-style checkpoint was trained with; prepended to every input.
PREFIX = "translate English to Chinese: "
| def latest_checkpoint(root: Path) -> Path | None: | |
| if not root.exists(): | |
| return None | |
| ckpts = [p for p in root.iterdir() if p.is_dir() and p.name.startswith("checkpoint-")] | |
| if ckpts: | |
| ckpts.sort(key=lambda p: int(p.name.split("-")[-1])) | |
| return ckpts[-1] | |
| # fallback: root itself contains model files | |
| if (root / "config.json").exists() or (root / "model.safetensors").exists(): | |
| return root | |
| return None | |
# Resolved once at import time; None means no usable checkpoint on disk.
CKPT = latest_checkpoint(MODEL_ROOT)

# Lazily-initialised singleton holding the tokenizer/model pair plus metadata.
_pipe = {"tok": None, "model": None, "device": None, "ckpt": None}


def model_ready():
    """Report whether a checkpoint directory was found on disk."""
    return CKPT is not None
def get_model():
    """Lazily load and cache the tokenizer/model pair in ``_pipe``.

    First call resolves the device, loads the checkpoint, and runs a
    one-shot smoke translation; subsequent calls return the cached objects.

    Returns:
        (tokenizer, model, device) from the module-level ``_pipe`` cache.

    Raises:
        RuntimeError: if no checkpoint directory was found at import time.
    """
    if _pipe["model"] is None:
        if CKPT is None:
            raise RuntimeError("No checkpoint found.")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Let HF pick the right tokenizer implementation for this checkpoint.
        tok = AutoTokenizer.from_pretrained(CKPT)
        model = AutoModelForSeq2SeqLM.from_pretrained(CKPT).to(device)
        model.eval()
        # NOTE: _pipe must be populated BEFORE the self-test below — the
        # self-test calls translate_text(), which calls back into get_model().
        _pipe.update({"tok": tok, "model": model, "device": device, "ckpt": str(CKPT)})
        # quick startup self-test: surfaces load/generation failures at boot
        try:
            test = translate_text("I am happy")
            print(f"[LOAD OK] ckpt={CKPT} device={device} test_out={repr(test)}")
        except Exception as e:
            print(f"[LOAD FAIL] ckpt={CKPT} device={device} err={e}")
            raise
    return _pipe["tok"], _pipe["model"], _pipe["device"]
def translate_text(text: str):
    """Translate English *text* to Chinese with the cached seq2seq model.

    Returns:
        "" for blank/None input; a bracketed diagnostic string when the
        model is missing or decoding yields an empty result; otherwise
        the decoded translation.
    """
    text = (text or "").strip()
    if not text:
        return ""
    if not model_ready():
        # Repaired mojibake in the user-facing message (garbled "β" -> em dash).
        return "[Model not ready — checkpoint folder not found.]"
    tok, model, device = get_model()
    prompt = PREFIX + text
    inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=256).to(device)
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=80,
            num_beams=4,
            early_stopping=True,
        )
    result = tok.decode(out_ids[0], skip_special_tokens=True).strip()
    # If decoding produced nothing, expose the raw tokens so blank
    # translations can be debugged directly from the UI.
    if not result:
        raw = tok.decode(out_ids[0])
        return f"[Blank output] raw={raw!r} ckpt={_pipe['ckpt']} device={device}"
    return result
| title = "EN β ZH Translator" | |
| status = "β Model found" if model_ready() else "β³ Model not found" | |
| with gr.Blocks() as demo: | |
| gr.Markdown(f"# {title}") | |
| gr.Markdown(f"**Status:** {status}") | |
| gr.Markdown(f"**Loaded path:** `{str(CKPT) if CKPT else 'None'}`") | |
| inp = gr.Textbox(label="English", lines=4, placeholder="Type English here...") | |
| out = gr.Textbox(label="Chinese", lines=4) | |
| btn = gr.Button("Translate") | |
| btn.click(translate_text, inp, out) | |
| demo.launch() | |