EcoFriendlyWoodVerneer committed on
Commit ef00eb2 · 1 Parent(s): a028332
Files changed (1)
  1. app.py +77 -75
app.py CHANGED
@@ -6,7 +6,10 @@ from tokenizers import Tokenizer
 
 ROOT = Path(__file__).parent
 
+# ✅ match your actual repo folder name (case-sensitive on Spaces)
 MODEL_ROOT = ROOT / "vMAX_512_8L_32k_OpenSubs"
+
+# ✅ adjust this to what your model was trained on
 PREFIX = "translate English to Indonesian: "
 
 
@@ -19,7 +22,6 @@ def latest_checkpoint(root: Path) -> Path | None:
     ckpts.sort(key=lambda p: int(p.name.split("-")[-1]))
     return ckpts[-1]
 
-    # fallback: root itself contains model files
     if (root / "config.json").exists() or (root / "model.safetensors").exists():
         return root
 
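For reference, the hunk above shows only the tail of `latest_checkpoint`. A sketch of the full helper, assuming Trainer-style `checkpoint-<step>` subfolders (the glob line is an assumption; the tail is verbatim from the diff):

```python
from pathlib import Path

def latest_checkpoint(root: Path) -> Path | None:
    # Assumed discovery step: training writes checkpoint-500, checkpoint-1000, ...
    ckpts = [p for p in root.glob("checkpoint-*") if p.is_dir()]
    if ckpts:
        # Numeric sort, so checkpoint-1000 outranks checkpoint-200
        ckpts.sort(key=lambda p: int(p.name.split("-")[-1]))
        return ckpts[-1]
    # Fallback: the root itself contains the model files
    if (root / "config.json").exists() or (root / "model.safetensors").exists():
        return root
    return None
```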
@@ -27,6 +29,9 @@ def latest_checkpoint(root: Path) -> Path | None:
 
 
 CKPT = latest_checkpoint(MODEL_ROOT)
+print("[BOOT] MODEL_ROOT =", MODEL_ROOT)
+print("[BOOT] CKPT =", CKPT)
+
 _pipe = {"tok": None, "model": None, "device": None, "ckpt": None}
 
 
@@ -36,15 +41,11 @@ def model_ready():
 
 class SimpleTokenizerWrapper:
     """
-    Minimal tokenizer wrapper around `tokenizers.Tokenizer` that matches what this app needs:
-    - callable returning PyTorch tensors (input_ids + attention_mask)
-    - decode(ids)
-    - eos_token_id attribute (optional)
+    Wrapper over tokenizers.Tokenizer so we avoid the TokenizersBackend crash on Spaces.
+    Needs: __call__ -> tensors, decode(), eos_token_id (optional)
     """
     def __init__(self, base_tok: Tokenizer):
         self.base_tok = base_tok
-
-        # Try to discover EOS/PAD ids if they exist in the vocab; safe if missing
         self.eos_token_id = self._find_id(["</s>", "<eos>", "[EOS]"])
         self.pad_token_id = self._find_id(["<pad>", "[PAD]"])
 
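`_find_id` itself sits outside the hunks shown. A plausible body, using `Tokenizer.token_to_id`, which returns `None` for tokens absent from the vocab (this sketch is ours, not verbatim from the commit):

```python
def _find_id(self, candidates: list[str]) -> int | None:
    # Return the id of the first candidate present in the vocab;
    # None if the tokenizer defines none of them (callers handle that).
    for token in candidates:
        token_id = self.base_tok.token_to_id(token)
        if token_id is not None:
            return token_id
    return None
```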
@@ -60,14 +61,9 @@ class SimpleTokenizerWrapper:
 
     def __call__(self, text, return_tensors="pt", truncation=True, max_length=256):
         enc = self.base_tok.encode(text)
-        ids = enc.ids
-
-        if truncation and max_length is not None:
-            ids = ids[:max_length]
-
+        ids = enc.ids[:max_length] if (truncation and max_length is not None) else enc.ids
         input_ids = torch.tensor([ids], dtype=torch.long)
         attention_mask = torch.ones_like(input_ids)
-
         return {"input_ids": input_ids, "attention_mask": attention_mask}
 
     def decode(self, ids, skip_special_tokens=True):
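The wrapper's contract is easiest to see in a quick round trip. A sketch, assuming `load_tokenizer` reads a `tokenizer.json` beside the checkpoint (that filename is our assumption; only the wrapper API is from the commit):

```python
from tokenizers import Tokenizer

base = Tokenizer.from_file(str(CKPT / "tokenizer.json"))  # assumed file name
tok = SimpleTokenizerWrapper(base)

enc = tok("translate English to Indonesian: I am happy", max_length=256)
print(enc["input_ids"].shape)       # torch.Size([1, n]); the wrapper adds the batch dim
print(enc["attention_mask"].sum())  # n, since a single sequence needs no padding
```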
@@ -84,98 +80,104 @@ def load_tokenizer(ckpt: Path) -> SimpleTokenizerWrapper:
     return SimpleTokenizerWrapper(base)
 
 
+def load_eos_meme_b64() -> str:
+    b64_path = ROOT / "eos_ack.b64"
+    if not b64_path.exists():
+        return ""
+    return b64_path.read_text(encoding="utf-8").strip()
+
+
+EOS_ACK_B64 = load_eos_meme_b64()
+
+
 def get_model():
-    if _pipe["model"] is None:
-        if CKPT is None:
-            raise RuntimeError(f"No checkpoint found under: {MODEL_ROOT}")
+    if _pipe["model"] is not None:
+        return _pipe["tok"], _pipe["model"], _pipe["device"]
 
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+    if CKPT is None:
+        raise RuntimeError(f"No checkpoint found under: {MODEL_ROOT}")
 
-        # bypass broken tokenizer_config.json (TokenizersBackend error on Spaces)
-        tok = load_tokenizer(CKPT)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print("[LOAD] device =", device)
+    print("[LOAD] loading from =", CKPT)
 
-        model = AutoModelForSeq2SeqLM.from_pretrained(CKPT).to(device)
-        model.eval()
+    tok = load_tokenizer(CKPT)
 
-        _pipe.update({"tok": tok, "model": model, "device": device, "ckpt": str(CKPT)})
+    # low-memory load
+    kwargs = dict(low_cpu_mem_usage=True)
+    if device == "cuda":
+        kwargs["torch_dtype"] = torch.float16  # comment out if you don’t want fp16
 
-        # quick startup self-test (will run on first translate click)
-        try:
-            test = translate_text("I am happy")
-            print(f"[LOAD OK] ckpt={CKPT} device={device} test_out={repr(test)}")
-        except Exception as e:
-            print(f"[LOAD FAIL] ckpt={CKPT} device={device} err={e}")
-            raise
+    model = AutoModelForSeq2SeqLM.from_pretrained(CKPT, **kwargs).to(device)
+    model.eval()
+
+    _pipe.update({"tok": tok, "model": model, "device": device, "ckpt": str(CKPT)})
+    print("[LOAD] OK")
 
-    return _pipe["tok"], _pipe["model"], _pipe["device"]
+    return tok, model, device
 
 
 def translate_text(text: str):
     text = (text or "").strip()
     if not text:
         return ""
+
     if not model_ready():
         return "[Model not ready — checkpoint folder not found.]"
 
-    tok, model, device = get_model()
-
-    prompt = PREFIX + text
-    inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=256)
-    inputs = {k: v.to(device) for k, v in inputs.items()}
+    try:
+        tok, model, device = get_model()
 
-    input_len = inputs["input_ids"].shape[1]
-    max_new = int(input_len * 1.3)
-    max_new = min(max_new, 40)
+        prompt = PREFIX + text
+        inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=256)
+        inputs = {k: v.to(device) for k, v in inputs.items()}
 
-    gen_kwargs = dict(
-        num_beams=2,
-        length_penalty=1.15,
-        no_repeat_ngram_size=4,
-        repetition_penalty=1.30,
-        early_stopping=True,
-        min_new_tokens=4,
-        max_new_tokens=max_new,
-    )
+        input_len = inputs["input_ids"].shape[1]
+        max_new = min(int(input_len * 1.3), 40)
 
-    # only pass eos_token_id if we found one
-    if getattr(tok, "eos_token_id", None) is not None:
-        gen_kwargs["eos_token_id"] = tok.eos_token_id
+        gen_kwargs = dict(
+            num_beams=2,
+            length_penalty=1.15,
+            no_repeat_ngram_size=4,
+            repetition_penalty=1.30,
+            early_stopping=True,
+            min_new_tokens=4,
+            max_new_tokens=max_new,
+        )
 
-    with torch.no_grad():
-        out_ids = model.generate(**inputs, **gen_kwargs)
+        if getattr(tok, "eos_token_id", None) is not None:
+            gen_kwargs["eos_token_id"] = tok.eos_token_id
 
-    result = tok.decode(out_ids[0], skip_special_tokens=True).strip()
+        with torch.no_grad():
+            out_ids = model.generate(**inputs, **gen_kwargs)
 
-    if not result:
-        raw = tok.decode(out_ids[0])
-        return f"[Blank output] raw={raw!r} ckpt={_pipe['ckpt']} device={device}"
+        return tok.decode(out_ids[0], skip_special_tokens=True).strip()
 
-    return result
+    except Exception as e:
+        return f"[ERROR] {type(e).__name__}: {e}"
 
 
-title = "Translator Demo"
-status = "✅ Model found" if model_ready() else "⏳ Model not found"
-
 with gr.Blocks() as demo:
-    gr.Markdown(f"# {title}")
-    gr.Markdown(f"**Status:** {status}")
+    gr.Markdown("# Translator Demo")
+    gr.Markdown(f"**Status:** {'✅ Model found' if model_ready() else '⏳ Model not found'}")
     gr.Markdown(f"**Model root:** `{MODEL_ROOT}`")
-    gr.Markdown(f"**Loaded checkpoint:** `{str(CKPT) if CKPT else 'None'}`")
+    gr.Markdown(f"**Checkpoint:** `{str(CKPT) if CKPT else 'None'}`")
    gr.Markdown(f"**Prefix:** `{PREFIX}`")
 
-    # ✅ EOS note + image
+    # ✅ EOS note + embedded meme (no png file, so HF won’t block the push)
     gr.Markdown("### ⚠️ Note: EOS wasn’t properly learnt by the model")
-    gr.Image(
-        value=str(ROOT / "eos_ack.png"),
-        show_label=False,
-        interactive=False,
-        height=260,
-    )
-
-    inp = gr.Textbox(label="Input", lines=4, placeholder="Type here...")
+    if EOS_ACK_B64:
+        gr.HTML(f"""
+            <div style="max-width:520px;">
+                <img src="data:image/png;base64,{EOS_ACK_B64}"
+                     style="width:100%; border-radius:12px;" />
+            </div>
+        """)
+    else:
+        gr.Markdown("_Missing `eos_ack.b64` — meme not loaded._")
+
+    inp = gr.Textbox(label="Input", lines=4)
     out = gr.Textbox(label="Output", lines=4)
-
-    btn = gr.Button("Translate")
-    btn.click(translate_text, inp, out)
+    gr.Button("Translate").click(translate_text, inp, out)
 
 demo.launch()
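Two things changed in `get_model` besides the early-return guard: weights now load with `low_cpu_mem_usage=True`, and fp16 is used on GPU. The loading pattern in isolation (the model path is a placeholder):

```python
import torch
from transformers import AutoModelForSeq2SeqLM

device = "cuda" if torch.cuda.is_available() else "cpu"
kwargs = dict(low_cpu_mem_usage=True)      # stream weights in, rather than materializing them twice
if device == "cuda":
    kwargs["torch_dtype"] = torch.float16  # halves VRAM; fp32 remains the CPU default
model = AutoModelForSeq2SeqLM.from_pretrained("path/to/checkpoint", **kwargs).to(device)
model.eval()                               # inference only: disables dropout
```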
 
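Since EOS was never properly learned, the cap on `max_new_tokens` is what actually stops generation. The same arithmetic, pulled out into a standalone helper (the function name is ours; the values are verbatim from the diff):

```python
def build_gen_kwargs(input_len: int, eos_token_id: int | None = None) -> dict:
    # Cap output at ~1.3x the input length, never more than 40 new tokens:
    # a 10-token prompt allows min(int(10 * 1.3), 40) = 13 new tokens.
    max_new = min(int(input_len * 1.3), 40)
    gen_kwargs = dict(
        num_beams=2,
        length_penalty=1.15,
        no_repeat_ngram_size=4,    # block verbatim 4-gram loops
        repetition_penalty=1.30,
        early_stopping=True,
        min_new_tokens=4,
        max_new_tokens=max_new,
    )
    if eos_token_id is not None:   # only pass EOS if the vocab actually has one
        gen_kwargs["eos_token_id"] = eos_token_id
    return gen_kwargs
```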
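`load_eos_meme_b64` expects a plain-text `eos_ack.b64` in the repo root. It can be produced once, locally, from the old PNG (this encode step is not part of the commit):

```python
import base64
from pathlib import Path

# One-off local step: base64-encode eos_ack.png into eos_ack.b64 so the
# Space ships a text file instead of a binary image.
png_bytes = Path("eos_ack.png").read_bytes()
Path("eos_ack.b64").write_text(base64.b64encode(png_bytes).decode("ascii"), encoding="utf-8")
```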
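The old startup self-test (`translate_text("I am happy")` behind the `[LOAD OK]` print) is gone, so nothing exercises the pipeline until the first click. An equivalent check can be run by hand, e.g. by temporarily adding this before `demo.launch()`:

```python
# Temporary smoke test, standing in for the removed [LOAD OK] self-check;
# drop it after verifying the checkpoint loads and decodes.
if model_ready():
    print(translate_text("I am happy"))  # expect a short Indonesian sentence
else:
    print("[BOOT] checkpoint folder not found")
```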
 