Kimang18 committed
Commit f2188a9 · verified · 1 parent: babb162

Use tror-yong-ocr package

Files changed (1): model.py +11 -194
model.py CHANGED
@@ -1,206 +1,23 @@
-from typing import Sequence
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from timm.models.vision_transformer import PatchEmbed, VisionTransformer
-from dataclasses import dataclass
-from torch import Tensor
-import math
-
-
-class CharTokenizer:
-    def __init__(self, chars, special_tokens=['<s>', '</s>', '<pad>', '<unk>']):
-        self.special_tokens = special_tokens
-        # Unique characters + special tokens
-        self.vocab = tuple(special_tokens[:1]) + tuple(chars) + tuple(special_tokens[1:])
-        self.str_to_int = {s: i for i, s in enumerate(self.vocab)}
-        self.int_to_str = {i: s for i, s in enumerate(self.vocab)}
-        self.bos_id = self.str_to_int['<s>']
-        self.eos_id = self.str_to_int['</s>']
-        self.pad_id = self.str_to_int['<pad>']
-        self.unk_id = self.str_to_int['<unk>']
-
-    def __len__(self):
-        return len(self.vocab)
-
-    def encode(self, text, add_special_tokens=False):
-        tokens = []
-        i = 0
-        while i < len(text):
-            matched_special = False
-            # Check for existing special tokens in the input string
-            for spec in self.special_tokens:
-                if text.startswith(spec, i):
-                    tokens.append(self.str_to_int[spec])
-                    i += len(spec)
-                    matched_special = True
-                    break
-
-            if not matched_special:
-                char = text[i]
-                tokens.append(self.str_to_int.get(char, self.str_to_int['<unk>']))
-                i += 1
-
-        # Wrap with <s> and </s> if requested
-        if add_special_tokens:
-            tokens = [self.str_to_int['<s>']] + tokens + [self.str_to_int['</s>']]
-
-        return tokens
-
-    def decode(self, ids, ignore_special_tokens=False):
-        if ignore_special_tokens:
-            # Filter out any ID that belongs to the special_tokens list
-            return "".join([self.int_to_str[i] for i in ids if self.int_to_str[i] not in self.special_tokens])
-
-        return "".join([self.int_to_str.get(i, '<unk>') for i in ids])
-
-
-class ImageEncoder(VisionTransformer):
-    def __init__(self, config):
-        super().__init__(
-            img_size=config.img_size,
-            patch_size=config.patch_size,
-            in_chans=config.n_channel,
-            embed_dim=config.n_embed,
-            depth=config.n_layer,
-            num_heads=config.n_head,
-            mlp_ratio=4,
-            qkv_bias=True,
-            drop_rate=0.0,
-            attn_drop_rate=0.0,
-            drop_path_rate=0.0,
-            embed_layer=PatchEmbed,
-            num_classes=0,      # These
-            global_pool='',     # disable the
-            class_token=False,  # classifier head.
-        )
-
-    def forward(self, x):
-        return self.forward_features(x)
-
-
-class RMSNorm(nn.RMSNorm):
-    def forward(self, x):
-        return super().forward(x.float()).type(x.dtype)
-
-
-class Linear(nn.Linear):
-    def forward(self, x: Tensor) -> Tensor:
-        return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
-
-
-class TextDecoder(nn.Module):
-    def __init__(self, config) -> None:
-        super().__init__()
-        self.config = config
-        self.n_head = 2 * config.n_head
-        self.tok_embed = nn.Embedding(config.vocab_size, config.n_embed)
-        self.pos_embed = nn.Parameter(torch.Tensor(
-            1, config.block_size, config.n_embed))
-        self.dropout = nn.Dropout(config.dropout)
-
-        self.sa_ln = RMSNorm(config.n_embed)
-        self.sa_attn = nn.MultiheadAttention(config.n_embed, self.n_head, dropout=config.dropout, batch_first=True)
-
-        self.cross_ln = RMSNorm(config.n_embed)
-        self.cross_attn = nn.MultiheadAttention(config.n_embed, self.n_head, dropout=config.dropout, batch_first=True)
-
-        self.ffn_ln = RMSNorm(config.n_embed)
-        dim_feedforward = 4 * config.n_embed
-        self.ffn = nn.Sequential(
-            Linear(config.n_embed, dim_feedforward, bias=config.bias),
-            nn.GELU(),
-            Linear(dim_feedforward, config.n_embed, bias=config.bias),
-            nn.Dropout(config.dropout)
-        )
-        self.lm_head = Linear(config.n_embed, config.vocab_size)
-        nn.init.trunc_normal_(self.pos_embed, std=0.02)
-
-    def forward(self, x: Tensor, xi: Tensor):
-        """
-        x: input token ids
-        xi: image features (already normalized by ImageEncoder)
-        """
-        b, t = x.size()
-        tok_embed = self.tok_embed(x) * math.sqrt(self.config.n_embed)
-
-        ctx = torch.cat(
-            [tok_embed[:, :1], self.pos_embed[:, :t-1] + tok_embed[:, 1:]], dim=1)
-        ctx = self.dropout(ctx)
-        ctx = self.sa_ln(ctx)
-        res = self.dropout(self.pos_embed[:, :t].expand(b, -1, -1))  # (b, t, n_embed)
-
-        mask = torch.triu(torch.ones((t, t), dtype=torch.bool, device=x.device), 1)
-        query, sa_weights = self.sa_attn(self.sa_ln(res), ctx, ctx, attn_mask=mask)
-        res = res + query
-        query, ca_weights = self.cross_attn(self.cross_ln(res), xi, xi)
-        res = res + query
-        res = res + self.ffn(self.ffn_ln(res))
-        return self.lm_head(res[:, [-1], :]).float()
-
-
-class OCRModel(nn.Module):
-    def __init__(self, config, tokenizer) -> None:
-        super().__init__()
-        self.encoder = ImageEncoder(config)
-        self.decoder = TextDecoder(config)
-        self.tokenizer = tokenizer
-
-    def forward(self, img_tensor: Tensor, input_tokens: Tensor):
-        xi = self.encoder(img_tensor)
-        logits, loss = self.decoder(input_tokens, xi)
-        return logits, loss
-
-    @torch.inference_mode()
-    def generate(self, img_tensor: Tensor, max_new_tokens: int, temperature=1.0, top_k=None):
-        xi = self.encoder(img_tensor.unsqueeze(0))
-        idx = torch.full((xi.size(0), 1), fill_value=self.tokenizer.bos_id, dtype=torch.long, device=img_tensor.device)
-        for i in range(max_new_tokens):
-            logits = self.decoder(idx, xi)
-            logits = logits[:, -1, :] / temperature
-            if top_k is not None:
-                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-                logits[logits < v[:, [-1]]] = -float('inf')
-            probs = F.softmax(logits, dim=-1)
-            idx_next = torch.multinomial(probs, num_samples=1)
-            idx = torch.cat((idx, idx_next), dim=1)
-            if idx_next.item() == self.tokenizer.eos_id:
-                break
-        return self.tokenizer.decode(idx[0].tolist(), ignore_special_tokens=True)
-
-
-@dataclass
-class ModelConfig:
-    img_size: Sequence[int]
-    patch_size: Sequence[int]
-    n_channel: int
-    vocab_size: int
-    block_size: int
-    n_layer: int
-    n_head: int
-    n_embed: int
-    dropout: float = 0.0
-    bias: bool = True
+import torch  # needed for torch.hub.load_state_dict_from_url below
+from tror_yong_ocr import TrorYongOCR, TrorYongConfig
+from tror_yong_ocr import get_tokenizer
 
 
 def load_model():
-    kh_charset = "០១២៣៤៥៦៧៨៩កខគឃងចឆជឈញដឋឌឍណតថទធនបផពភមយរលវសហឡអឥឧឳឪឱឫឬឭឮឦឰឯាិីឹឺុូួើឿៀេែៃោៅំះៈ់៉៊៍័៏៌្ ។៕៖ៗ"
-    en_charset = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
-    tokenizer = CharTokenizer(en_charset+kh_charset)
-
-    config = ModelConfig(
+    tokenizer = get_tokenizer()
+    config = TrorYongConfig(
         img_size=(32, 128),
         patch_size=(4, 8),
         n_channel=3,
-        vocab_size=len(tokenizer),
+        vocab_size=len(tokenizer),  # exclude pad and unk tokens
         block_size=192,
-        n_layer=12,
-        n_head=3,
-        n_embed=192,
+        n_layer=4,
+        n_head=6,
+        n_embed=384,
         dropout=0.1,
         bias=True,
     )
-    model = OCRModel(config, tokenizer)
-    state_dict = torch.hub.load_state_dict_from_url('https://huggingface.co/KrorngAI/PARSeqForKhmer/resolve/main/parseq_kh.pt', map_location=torch.device('cpu'))
+    model = TrorYongOCR(config, tokenizer)
+    state_dict = torch.hub.load_state_dict_from_url('https://huggingface.co/KrorngAI/PARSeqForKhmer/resolve/main/best_model-80epoch.pt', map_location=torch.device('cpu'))
     model.load_state_dict(state_dict)
     return model
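
For reference, a minimal sketch of how the slimmed-down module might be used after this change. It assumes the tror-yong-ocr package is installed and that TrorYongOCR keeps the same generate() interface as the removed OCRModel (one 3x32x128 image tensor in, a decoded string out); neither is guaranteed by this diff.

    import torch

    from model import load_model

    model = load_model()  # downloads best_model-80epoch.pt on first call
    model.eval()

    # The config above expects 3-channel 32x128 line crops; a random tensor
    # stands in for a preprocessed image here.
    img = torch.rand(3, 32, 128)

    # Assumption: generate() matches the removed OCRModel.generate(), which
    # starts from the BOS token, samples up to max_new_tokens ids, stops at
    # EOS, and returns the decoded string with special tokens stripped.
    text = model.generate(img, max_new_tokens=192, temperature=1.0, top_k=1)
    print(text)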