metythorn commited on
Commit
a9e2711
·
1 Parent(s): 8491e67

add sample for testing

Browse files
Files changed (6) hide show
  1. app.py +54 -139
  2. image.png +0 -0
  3. image1.png +0 -0
  4. image2.png +0 -0
  5. image3.png +0 -0
  6. image4.png +0 -0
app.py CHANGED
@@ -1,13 +1,16 @@
1
  # app.py
2
- # Minimal Gradio app:
3
- # - User uploads an image
4
- # - App loads your private HF repo (best.pt + vocab_char.json) using HF_TOKEN secret
5
- # - Returns OCR text
6
  #
7
- # Hugging Face Space:
8
- # 1) Add a Space secret named HF_TOKEN (Settings → Secrets)
9
- # 2) Make sure your private model repo contains: best.pt, vocab_char.json
10
- # 3) requirements.txt should include python-multipart>=0.0.9
 
 
 
 
11
 
12
  import os
13
  import math
@@ -31,6 +34,14 @@ REPO_ID = "SoyVitou/infinity-khmer-ocr-large"
31
  CKPT_FILENAME = "best.pt"
32
  VOCAB_FILENAME = "vocab_char.json"
33
 
 
 
 
 
 
 
 
 
34
 
35
  @dataclass
36
  class CFG:
@@ -75,7 +86,6 @@ class CFG:
75
  UNK_LOGP_PENALTY: float = 1.0
76
 
77
  EMA_DECAY: float = 0.999
78
- VOCAB_JSON: str = VOCAB_FILENAME
79
 
80
  USE_FP16: bool = True
81
  USE_AUTOCAST: bool = True
@@ -83,13 +93,7 @@ class CFG:
83
 
84
 
85
  class CharTokenizer:
86
- def __init__(
87
- self,
88
- vocab_json: str,
89
- unk_token: str = "<unk>",
90
- collapse_whitespace: bool = True,
91
- unicode_nfc: bool = True
92
- ):
93
  with open(vocab_json, "r", encoding="utf-8") as f:
94
  vocab_raw: Dict[str, int] = json.load(f)
95
 
@@ -105,8 +109,6 @@ class CharTokenizer:
105
 
106
  self.unk_token = unk_token
107
  self.unk_id = self.token_to_id[unk_token]
108
- self.collapse_whitespace = collapse_whitespace
109
- self.unicode_nfc = unicode_nfc
110
 
111
  self.blank_id = 0
112
  self.pad_id = 1
@@ -218,8 +220,6 @@ class ConvStem(nn.Module):
218
  class HybridContextOCRV2(nn.Module):
219
  def __init__(self, cfg: CFG, tok: CharTokenizer):
220
  super().__init__()
221
- self.cfg = cfg
222
- self.tok = tok
223
  d = cfg.DROPOUT
224
 
225
  self.stem = ConvStem(cfg.ENC_DIM, d)
@@ -238,35 +238,29 @@ class HybridContextOCRV2(nn.Module):
238
  self.enc = nn.TransformerEncoder(enc_layer, num_layers=cfg.ENC_LAYERS)
239
  self.enc_ln = nn.LayerNorm(cfg.ENC_DIM)
240
 
241
- self.use_ctc = cfg.USE_CTC
242
- if self.use_ctc:
243
- self.ctc_head = nn.Sequential(
244
- nn.LayerNorm(cfg.ENC_DIM),
245
- nn.Dropout(d),
246
- nn.Linear(cfg.ENC_DIM, tok.ctc_classes),
247
- )
248
-
249
- self.use_decoder = cfg.USE_DECODER
250
- if self.use_decoder:
251
- self.mem_proj = nn.Linear(cfg.ENC_DIM, cfg.DEC_DIM, bias=False)
252
- self.dec_emb = nn.Embedding(tok.dec_vocab, cfg.DEC_DIM)
253
-
254
- dec_layer = nn.TransformerDecoderLayer(
255
- d_model=cfg.DEC_DIM,
256
- nhead=cfg.DEC_HEADS,
257
- dim_feedforward=cfg.DEC_FF,
258
- dropout=d,
259
- batch_first=True,
260
- activation="gelu",
261
- norm_first=True,
262
- )
263
- self.dec = nn.TransformerDecoder(dec_layer, num_layers=cfg.DEC_LAYERS)
264
- self.dec_ln = nn.LayerNorm(cfg.DEC_DIM)
265
- self.dec_head = nn.Linear(cfg.DEC_DIM, tok.dec_vocab)
266
-
267
- self.use_lm = cfg.USE_LM
268
- if self.use_lm:
269
- self.lm_head = nn.Linear(cfg.DEC_DIM, tok.dec_vocab)
270
 
271
  def encode(self, imgs: torch.Tensor) -> torch.Tensor:
272
  x = self.stem(imgs)
@@ -280,8 +274,7 @@ class HybridContextOCRV2(nn.Module):
280
 
281
 
282
  class EMA:
283
- def __init__(self, model: nn.Module, decay: float):
284
- self.decay = decay
285
  self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
286
 
287
  @torch.no_grad()
@@ -330,9 +323,6 @@ def beam_decode_one_batched(
330
  max_steps = min(max_steps, max(1, int(target_len * cfg.DEC_MAX_LEN_RATIO) + cfg.DEC_MAX_LEN_PAD))
331
  else:
332
  max_steps = min(max_steps, cfg.DEC_MAX_LEN_PAD)
333
- else:
334
- mem_len = mem_proj_1.size(1)
335
- max_steps = min(max_steps, max(1, int(mem_len * cfg.MEM_MAX_LEN_RATIO) + cfg.DEC_MAX_LEN_PAD))
336
 
337
  full_causal = torch.triu(
338
  torch.ones((cfg.MAX_DEC_LEN + 2, cfg.MAX_DEC_LEN + 2), device=device, dtype=torch.bool),
@@ -368,28 +358,15 @@ def beam_decode_one_batched(
368
  logits = model.dec_head(out)[:, -1, :]
369
  logp = F.log_softmax(logits, dim=-1)
370
 
371
- if cfg.USE_LM and cfg.USE_LM_FUSION_EVAL:
372
- lm_logits = model.lm_head(out)[:, -1, :]
373
- logp = logp + cfg.LM_FUSION_ALPHA * F.log_softmax(lm_logits, dim=-1)
374
 
375
  unk_id = tok.unk_id + tok.dec_offset
376
 
377
  for i, (_, seq, _) in enumerate(alive):
378
  cur_len = len(seq) - 1
379
-
380
- if target_len is not None and target_len > 0:
381
- min_len = min(cfg.EOS_BIAS_UNTIL_LEN, max(1, int(target_len * 0.7)))
382
- if cur_len < min_len:
383
- logp[i, tok.dec_eos] = logp[i, tok.dec_eos] - cfg.EOS_LOGP_BIAS
384
- elif cur_len >= target_len:
385
- logp[i, tok.dec_eos] = logp[i, tok.dec_eos] + cfg.EOS_LOGP_BOOST
386
- else:
387
- if cur_len < cfg.EOS_BIAS_UNTIL_LEN:
388
- logp[i, tok.dec_eos] = logp[i, tok.dec_eos] - cfg.EOS_LOGP_BIAS
389
-
390
- if len(seq) >= 4 and seq[-1] == seq[-2] == seq[-3]:
391
- logp[i, seq[-1]] = logp[i, seq[-1]] - cfg.REPEAT_LAST_PENALTY
392
-
393
  logp[i, unk_id] = logp[i, unk_id] - cfg.UNK_LOGP_PENALTY
394
 
395
  topv, topi = torch.topk(logp, k=cfg.BEAM, dim=-1)
@@ -410,65 +387,7 @@ def beam_decode_one_batched(
410
  new_beams.sort(key=lambda x: normed(x[0], x[1]), reverse=True)
411
  beams = new_beams[:cfg.BEAM]
412
 
413
- def length_norm(score: float, seq: List[int]) -> float:
414
- return score / (max(1, len(seq) - 1) ** cfg.BEAM_LENP)
415
-
416
- if ctc_logits_1 is not None and cfg.CTC_FUSION_ALPHA > 0:
417
- log_probs = F.log_softmax(ctc_logits_1.squeeze(0), dim=-1)
418
-
419
- def ctc_sequence_log_prob(label_ids: List[int]) -> torch.Tensor:
420
- if len(label_ids) == 0:
421
- return log_probs[:, tok.blank_id].sum()
422
-
423
- blank = tok.blank_id
424
- ext = [blank]
425
- for lid in label_ids:
426
- ext.append(lid)
427
- ext.append(blank)
428
-
429
- s_len = len(ext)
430
- alpha = log_probs.new_full((s_len,), float("-inf"))
431
- alpha[0] = log_probs[0, blank]
432
- alpha[1] = log_probs[0, ext[1]]
433
-
434
- for t in range(1, log_probs.size(0)):
435
- next_alpha = log_probs.new_full((s_len,), float("-inf"))
436
- for s in range(s_len):
437
- candidates = [alpha[s]]
438
- if s - 1 >= 0:
439
- candidates.append(alpha[s - 1])
440
- if s - 2 >= 0 and ext[s] != blank and ext[s] != ext[s - 2]:
441
- candidates.append(alpha[s - 2])
442
- next_alpha[s] = torch.logsumexp(torch.stack(candidates), dim=0) + log_probs[t, ext[s]]
443
- alpha = next_alpha
444
-
445
- if s_len == 1:
446
- return alpha[0]
447
- return torch.logsumexp(torch.stack([alpha[s_len - 1], alpha[s_len - 2]]), dim=0)
448
-
449
- def seq_to_ctc_labels(seq: List[int]) -> List[int]:
450
- labels = []
451
- for x in seq[1:]:
452
- if x == tok.dec_eos:
453
- break
454
- if x in (tok.dec_pad, tok.dec_bos):
455
- continue
456
- y = x - tok.dec_offset
457
- if 0 <= y < tok.vocab_size:
458
- labels.append(y + tok.ctc_offset)
459
- else:
460
- labels.append(tok.unk_id + tok.ctc_offset)
461
- return labels
462
-
463
- def combined_score(entry):
464
- dec_score = length_norm(entry[0], entry[1])
465
- labels = seq_to_ctc_labels(entry[1])
466
- ctc_score = ctc_sequence_log_prob(labels) / max(1, len(labels))
467
- return dec_score + cfg.CTC_FUSION_ALPHA * float(ctc_score)
468
-
469
- best = max(beams, key=combined_score)[1]
470
- else:
471
- best = max(beams, key=lambda x: length_norm(x[0], x[1]))[1]
472
 
473
  ids = []
474
  for x in best[1:]:
@@ -515,12 +434,7 @@ def load_model():
515
  if hasattr(cfg, k):
516
  setattr(cfg, k, v)
517
 
518
- tok = CharTokenizer(
519
- vocab_path,
520
- unk_token=cfg.UNK_TOKEN,
521
- collapse_whitespace=cfg.COLLAPSE_WHITESPACE,
522
- unicode_nfc=cfg.UNICODE_NFC,
523
- )
524
 
525
  device = setup_device(cfg)
526
 
@@ -528,7 +442,7 @@ def load_model():
528
  model.load_state_dict(ckpt["model"], strict=True)
529
 
530
  if isinstance(ckpt, dict) and "ema" in ckpt and isinstance(ckpt["ema"], dict):
531
- ema = EMA(model, decay=cfg.EMA_DECAY)
532
  ema.shadow = {k: v.detach().clone() for k, v in ckpt["ema"].items()}
533
  ema.copy_to(model)
534
 
@@ -550,10 +464,9 @@ def predict(img: Image.Image) -> str:
550
  x = preprocess_pil(CFG_OBJ, img).to(DEVICE)
551
  if CFG_OBJ.USE_FP16 and DEVICE == "cuda":
552
  x = x.half()
553
-
554
  mem = MODEL.encode(x)
555
  mem_proj = MODEL.mem_proj(mem)
556
- ctc_logits = MODEL.ctc_head(mem) if CFG_OBJ.USE_CTC else None
557
  return beam_decode_one_batched(MODEL, mem_proj, TOK, CFG_OBJ, ctc_logits_1=ctc_logits)
558
 
559
 
@@ -562,6 +475,8 @@ demo = gr.Interface(
562
  inputs=gr.Image(type="pil", label="Upload image"),
563
  outputs=gr.Textbox(label="OCR result", lines=6),
564
  title="Infinity Khmer OCR",
 
 
565
  )
566
 
567
  if __name__ == "__main__":
 
1
  # app.py
2
+ # Minimal Gradio app with Examples:
3
+ # - Loads your private HF model repo using HF_TOKEN (Space secret)
4
+ # - User can upload an image OR click an example image to test quickly
 
5
  #
6
+ # Put sample images in the Space repo root (same folder as app.py):
7
+ # image.png, image1.png, image2.png, image3.png, image4.png
8
+ #
9
+ # Space Secrets:
10
+ # HF_TOKEN = <your HF access token with access to the private model repo>
11
+ #
12
+ # Private model repo must contain:
13
+ # best.pt, vocab_char.json
14
 
15
  import os
16
  import math
 
34
  CKPT_FILENAME = "best.pt"
35
  VOCAB_FILENAME = "vocab_char.json"
36
 
37
+ EXAMPLES = [
38
+ ["./image.png"],
39
+ ["./image1.png"],
40
+ ["./image2.png"],
41
+ ["./image3.png"],
42
+ ["./image4.png"],
43
+ ]
44
+
45
 
46
  @dataclass
47
  class CFG:
 
86
  UNK_LOGP_PENALTY: float = 1.0
87
 
88
  EMA_DECAY: float = 0.999
 
89
 
90
  USE_FP16: bool = True
91
  USE_AUTOCAST: bool = True
 
93
 
94
 
95
  class CharTokenizer:
96
+ def __init__(self, vocab_json: str, unk_token: str = "<unk>", collapse_whitespace: bool = True, unicode_nfc: bool = True):
 
 
 
 
 
 
97
  with open(vocab_json, "r", encoding="utf-8") as f:
98
  vocab_raw: Dict[str, int] = json.load(f)
99
 
 
109
 
110
  self.unk_token = unk_token
111
  self.unk_id = self.token_to_id[unk_token]
 
 
112
 
113
  self.blank_id = 0
114
  self.pad_id = 1
 
220
  class HybridContextOCRV2(nn.Module):
221
  def __init__(self, cfg: CFG, tok: CharTokenizer):
222
  super().__init__()
 
 
223
  d = cfg.DROPOUT
224
 
225
  self.stem = ConvStem(cfg.ENC_DIM, d)
 
238
  self.enc = nn.TransformerEncoder(enc_layer, num_layers=cfg.ENC_LAYERS)
239
  self.enc_ln = nn.LayerNorm(cfg.ENC_DIM)
240
 
241
+ self.ctc_head = nn.Sequential(
242
+ nn.LayerNorm(cfg.ENC_DIM),
243
+ nn.Dropout(d),
244
+ nn.Linear(cfg.ENC_DIM, tok.ctc_classes),
245
+ )
246
+
247
+ self.mem_proj = nn.Linear(cfg.ENC_DIM, cfg.DEC_DIM, bias=False)
248
+ self.dec_emb = nn.Embedding(tok.dec_vocab, cfg.DEC_DIM)
249
+
250
+ dec_layer = nn.TransformerDecoderLayer(
251
+ d_model=cfg.DEC_DIM,
252
+ nhead=cfg.DEC_HEADS,
253
+ dim_feedforward=cfg.DEC_FF,
254
+ dropout=d,
255
+ batch_first=True,
256
+ activation="gelu",
257
+ norm_first=True,
258
+ )
259
+ self.dec = nn.TransformerDecoder(dec_layer, num_layers=cfg.DEC_LAYERS)
260
+ self.dec_ln = nn.LayerNorm(cfg.DEC_DIM)
261
+ self.dec_head = nn.Linear(cfg.DEC_DIM, tok.dec_vocab)
262
+
263
+ self.lm_head = nn.Linear(cfg.DEC_DIM, tok.dec_vocab)
 
 
 
 
 
 
264
 
265
  def encode(self, imgs: torch.Tensor) -> torch.Tensor:
266
  x = self.stem(imgs)
 
274
 
275
 
276
  class EMA:
277
+ def __init__(self, model: nn.Module):
 
278
  self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
279
 
280
  @torch.no_grad()
 
323
  max_steps = min(max_steps, max(1, int(target_len * cfg.DEC_MAX_LEN_RATIO) + cfg.DEC_MAX_LEN_PAD))
324
  else:
325
  max_steps = min(max_steps, cfg.DEC_MAX_LEN_PAD)
 
 
 
326
 
327
  full_causal = torch.triu(
328
  torch.ones((cfg.MAX_DEC_LEN + 2, cfg.MAX_DEC_LEN + 2), device=device, dtype=torch.bool),
 
358
  logits = model.dec_head(out)[:, -1, :]
359
  logp = F.log_softmax(logits, dim=-1)
360
 
361
+ lm_logits = model.lm_head(out)[:, -1, :]
362
+ logp = logp + cfg.LM_FUSION_ALPHA * F.log_softmax(lm_logits, dim=-1)
 
363
 
364
  unk_id = tok.unk_id + tok.dec_offset
365
 
366
  for i, (_, seq, _) in enumerate(alive):
367
  cur_len = len(seq) - 1
368
+ if cur_len < cfg.EOS_BIAS_UNTIL_LEN:
369
+ logp[i, tok.dec_eos] = logp[i, tok.dec_eos] - cfg.EOS_LOGP_BIAS
 
 
 
 
 
 
 
 
 
 
 
 
370
  logp[i, unk_id] = logp[i, unk_id] - cfg.UNK_LOGP_PENALTY
371
 
372
  topv, topi = torch.topk(logp, k=cfg.BEAM, dim=-1)
 
387
  new_beams.sort(key=lambda x: normed(x[0], x[1]), reverse=True)
388
  beams = new_beams[:cfg.BEAM]
389
 
390
+ best = max(beams, key=lambda x: x[0])[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
 
392
  ids = []
393
  for x in best[1:]:
 
434
  if hasattr(cfg, k):
435
  setattr(cfg, k, v)
436
 
437
+ tok = CharTokenizer(vocab_path, unk_token=cfg.UNK_TOKEN)
 
 
 
 
 
438
 
439
  device = setup_device(cfg)
440
 
 
442
  model.load_state_dict(ckpt["model"], strict=True)
443
 
444
  if isinstance(ckpt, dict) and "ema" in ckpt and isinstance(ckpt["ema"], dict):
445
+ ema = EMA(model)
446
  ema.shadow = {k: v.detach().clone() for k, v in ckpt["ema"].items()}
447
  ema.copy_to(model)
448
 
 
464
  x = preprocess_pil(CFG_OBJ, img).to(DEVICE)
465
  if CFG_OBJ.USE_FP16 and DEVICE == "cuda":
466
  x = x.half()
 
467
  mem = MODEL.encode(x)
468
  mem_proj = MODEL.mem_proj(mem)
469
+ ctc_logits = MODEL.ctc_head(mem)
470
  return beam_decode_one_batched(MODEL, mem_proj, TOK, CFG_OBJ, ctc_logits_1=ctc_logits)
471
 
472
 
 
475
  inputs=gr.Image(type="pil", label="Upload image"),
476
  outputs=gr.Textbox(label="OCR result", lines=6),
477
  title="Infinity Khmer OCR",
478
+ examples=EXAMPLES,
479
+ cache_examples=False,
480
  )
481
 
482
  if __name__ == "__main__":
image.png ADDED
image1.png ADDED
image2.png ADDED
image3.png ADDED
image4.png ADDED