File size: 2,963 Bytes
8ccad0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from pathlib import Path
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

ROOT = Path(__file__).parent
MODEL_ROOT = ROOT / "custom_t5_enzh"
PREFIX = "translate English to Chinese: "

def latest_checkpoint(root: Path) -> Path | None:
    if not root.exists():
        return None

    ckpts = [p for p in root.iterdir() if p.is_dir() and p.name.startswith("checkpoint-")]
    if ckpts:
        ckpts.sort(key=lambda p: int(p.name.split("-")[-1]))
        return ckpts[-1]

    # fallback: root itself contains model files
    if (root / "config.json").exists() or (root / "model.safetensors").exists():
        return root

    return None

# Resolved once at import time; None when the training output folder is absent.
CKPT = latest_checkpoint(MODEL_ROOT)

# Lazy singleton cache for the loaded pipeline; populated by get_model().
_pipe = {"tok": None, "model": None, "device": None, "ckpt": None}

def model_ready():
    """Report whether a checkpoint directory was discovered at import time."""
    found = CKPT is not None
    return found

def get_model():
    """Lazily load the tokenizer/model pair into the ``_pipe`` cache.

    Returns:
        (tokenizer, model, device) from the cache, loading from ``CKPT``
        on first call.

    Raises:
        RuntimeError: when no checkpoint directory was found.
        Exception: re-raises any failure of the startup self-test so a
            broken checkpoint surfaces immediately instead of at first
            user request.
    """
    if _pipe["model"] is None:
        if CKPT is None:
            raise RuntimeError("No checkpoint found.")

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # ✅ let HF choose the right tokenizer implementation
        tok = AutoTokenizer.from_pretrained(CKPT)
        model = AutoModelForSeq2SeqLM.from_pretrained(CKPT).to(device)
        model.eval()

        # IMPORTANT: the cache must be populated BEFORE the self-test below —
        # translate_text() re-enters get_model(), and the `_pipe["model"] is
        # None` guard is what prevents a reload/recursion loop.
        _pipe.update({"tok": tok, "model": model, "device": device, "ckpt": str(CKPT)})

        # quick startup self-test
        try:
            test = translate_text("I am happy")
            print(f"[LOAD OK] ckpt={CKPT} device={device} test_out={repr(test)}")
        except Exception as e:
            print(f"[LOAD FAIL] ckpt={CKPT} device={device} err={e}")
            raise

    return _pipe["tok"], _pipe["model"], _pipe["device"]

def translate_text(text: str):
    """Translate English *text* to Chinese using the cached seq2seq model.

    Returns the decoded translation, an empty string for blank input, or a
    bracketed status/debug message when the model is missing or the decode
    comes back empty.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return ""
    if not model_ready():
        return "[Model not ready — checkpoint folder not found.]"

    tok, mdl, dev = get_model()

    encoded = tok(
        PREFIX + cleaned,
        return_tensors="pt",
        truncation=True,
        max_length=256,
    ).to(dev)

    with torch.no_grad():
        generated = mdl.generate(
            **encoded,
            max_new_tokens=80,
            num_beams=4,
            early_stopping=True,
        )

    answer = tok.decode(generated[0], skip_special_tokens=True).strip()
    if answer:
        return answer

    # Decoding produced nothing useful — surface debug details rather than
    # showing an empty output box.
    raw = tok.decode(generated[0])
    return f"[Blank output] raw={raw!r} ckpt={_pipe['ckpt']} device={dev}"

# --- Gradio UI wiring -------------------------------------------------------
title = "EN → ZH Translator"
# Status reflects only whether a checkpoint folder exists, not whether the
# model actually loads (that happens lazily on first translate).
status = "✅ Model found" if model_ready() else "⏳ Model not found"

with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(f"**Status:** {status}")
    gr.Markdown(f"**Loaded path:** `{str(CKPT) if CKPT else 'None'}`")

    inp = gr.Textbox(label="English", lines=4, placeholder="Type English here...")
    out = gr.Textbox(label="Chinese", lines=4)

    btn = gr.Button("Translate")
    # Pipe the English textbox through translate_text into the output box.
    btn.click(translate_text, inp, out)

# Blocking call: starts the local web server.
demo.launch()