from pathlib import Path

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

ROOT = Path(__file__).parent
MODEL_ROOT = ROOT / "custom_t5_enzh"
PREFIX = "translate English to Chinese: "


def latest_checkpoint(root: Path) -> Path | None:
    """Return the most recent ``checkpoint-N`` directory under *root*.

    Falls back to *root* itself when it directly contains model files
    (``config.json`` or ``model.safetensors``); returns ``None`` when
    neither exists.
    """
    if not root.exists():
        return None
    ckpts = [
        p
        for p in root.iterdir()
        if p.is_dir() and p.name.startswith("checkpoint-")
    ]
    if ckpts:
        # Sort numerically on the step suffix so e.g. checkpoint-1000
        # outranks checkpoint-200 (lexicographic order would not).
        ckpts.sort(key=lambda p: int(p.name.split("-")[-1]))
        return ckpts[-1]
    # fallback: root itself contains model files
    if (root / "config.json").exists() or (root / "model.safetensors").exists():
        return root
    return None


CKPT = latest_checkpoint(MODEL_ROOT)

# Lazily populated singleton cache for the loaded pipeline pieces.
_pipe = {"tok": None, "model": None, "device": None, "ckpt": None}


def model_ready() -> bool:
    """True when a checkpoint directory was located at import time."""
    return CKPT is not None


def get_model():
    """Load the tokenizer/model once and return ``(tok, model, device)``.

    Raises:
        RuntimeError: if no checkpoint directory was found.

    On first load, a quick self-test translation is run and its outcome
    printed, so broken checkpoints are visible at startup.
    """
    if _pipe["model"] is None:
        if CKPT is None:
            raise RuntimeError("No checkpoint found.")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # ✅ let HF choose the right tokenizer implementation
        tok = AutoTokenizer.from_pretrained(CKPT)
        model = AutoModelForSeq2SeqLM.from_pretrained(CKPT).to(device)
        model.eval()
        # NOTE: the cache is populated BEFORE the self-test below.
        # translate_text() calls get_model() again; this assignment makes
        # that nested call a cheap cache hit instead of infinite recursion.
        _pipe.update({"tok": tok, "model": model, "device": device, "ckpt": str(CKPT)})
        # quick startup self-test
        try:
            test = translate_text("I am happy")
            print(f"[LOAD OK] ckpt={CKPT} device={device} test_out={repr(test)}")
        except Exception as e:
            print(f"[LOAD FAIL] ckpt={CKPT} device={device} err={e}")
            raise
    return _pipe["tok"], _pipe["model"], _pipe["device"]


def translate_text(text: str) -> str:
    """Translate English *text* to Chinese via the cached seq2seq model.

    Returns "" for empty input, a bracketed status message when the model
    is unavailable or produces a blank decoding, and the translation
    otherwise.
    """
    text = (text or "").strip()
    if not text:
        return ""
    if not model_ready():
        return "[Model not ready — checkpoint folder not found.]"
    tok, model, device = get_model()
    prompt = PREFIX + text
    inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=256).to(device)
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=80,
            num_beams=4,
            early_stopping=True,
        )
    result = tok.decode(out_ids[0], skip_special_tokens=True).strip()
    # extra debug if blank: re-decode WITH special tokens so we can see
    # what the model actually emitted (e.g. only <pad>/</s>).
    if not result:
        raw = tok.decode(out_ids[0])
        return f"[Blank output] raw={raw!r} ckpt={_pipe['ckpt']} device={device}"
    return result


title = "EN → ZH Translator"
status = "✅ Model found" if model_ready() else "⏳ Model not found"

with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(f"**Status:** {status}")
    gr.Markdown(f"**Loaded path:** `{str(CKPT) if CKPT else 'None'}`")
    inp = gr.Textbox(label="English", lines=4, placeholder="Type English here...")
    out = gr.Textbox(label="Chinese", lines=4)
    btn = gr.Button("Translate")
    btn.click(translate_text, inp, out)

# Guard the server start so importing this module (e.g. for testing or
# `gradio` CLI reload) does not immediately launch the app.
if __name__ == "__main__":
    demo.launch()