Image-to-Text
Transformers
Safetensors
Khmer
khmer-ocr
feature-extraction
transformer
text-recognition
crnn
khmer-text-recognition
custom_code
Darayut commited on
Commit
f5bb5c5
·
verified ·
1 Parent(s): b525f82

Upload modeling_khmerocr.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_khmerocr.py +195 -0
modeling_khmerocr.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modeling_khmerocr.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn.utils.rnn import pad_sequence
5
+ from transformers import PreTrainedModel
6
+ from configuration_khmerocr import KhmerOCRConfig
7
+
8
+ # ==========================================
9
+ # 1. HELPER CLASSES (SequenceSE, CNN, etc.)
10
+ # ==========================================
11
+
12
class SequenceSE(nn.Module):
    """Squeeze-and-excitation gate applied along the width (sequence) axis.

    Pools a (B, C, H, W) feature map over height, learns per-(channel, column)
    gates with a 1-D bottleneck conv, and rescales the input by those gates.
    """

    def __init__(self, channels, reduction=16):
        super(SequenceSE, self).__init__()
        # Guard against a zero-width bottleneck when channels < reduction
        # (nn.Conv1d with 0 output channels would fail at construction).
        mid = max(1, channels // reduction)
        self.fc = nn.Sequential(
            nn.Conv1d(channels, mid, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(mid, channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # x: (B, C, H, W); squeeze over height, gate per (channel, width).
        b, c, h, w = x.size()
        y = torch.mean(x, dim=2).view(b, c, w)
        y = self.fc(y)
        # Gates are in (0, 1); broadcast back over the height dimension.
        y = y.view(b, c, 1, w)
        return x * y
28
+
29
class ImprovedFeatureExtractor(nn.Module):
    """CNN backbone for OCR: maps (B, 1, H, W) images to a fixed (B, 512, 2, 32) feature map."""

    @staticmethod
    def _conv_block(c_in, c_out):
        # 3x3 conv + BatchNorm + ReLU; spatial size preserved (stride 1, pad 1).
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, 1, 1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
        )

    def __init__(self):
        super().__init__()
        # Stages 1-2: plain conv blocks, each followed by a 2x2 pool.
        self.conv1 = self._conv_block(1, 64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = self._conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)
        # Stage 3: double conv + SE gate, pooling height only to keep width resolution.
        self.conv3 = self._conv_block(128, 256)
        self.conv4 = self._conv_block(256, 256)
        self.se3 = SequenceSE(256)
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        # Stage 4: same pattern at 512 channels.
        self.conv5 = self._conv_block(256, 512)
        self.conv6 = self._conv_block(512, 512)
        self.se4 = SequenceSE(512)
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        # Stage 5: final conv kept as separate attributes (matches checkpoint layout).
        self.conv7 = nn.Conv2d(512, 512, 3, 1, 1)
        self.bn7 = nn.BatchNorm2d(512)
        self.relu7 = nn.ReLU(True)
        self.se5 = SequenceSE(512)
        # Force a fixed output grid regardless of input size.
        self.final_pool = nn.AdaptiveAvgPool2d((2, 32))

    def forward(self, x):
        """Run all stages; returns a (B, 512, 2, 32) tensor."""
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        x = self.pool3(self.se3(self.conv4(self.conv3(x))))
        x = self.pool4(self.se4(self.conv6(self.conv5(x))))
        x = self.se5(self.relu7(self.bn7(self.conv7(x))))
        return self.final_pool(x)
63
+
64
class PatchEncoder(nn.Module):
    """Project a CNN feature map into a sequence of patch embeddings.

    A strided conv "patchifies" the (B, C, H, W) feature map into
    non-overlapping (k1 x k2) patches, the spatial grid is flattened into a
    sequence, and learned positional embeddings are added.
    """

    def __init__(self, in_channels, emb_dim, k1=2, k2=1, max_patches=256):
        super().__init__()
        # kernel == stride -> non-overlapping patches.
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=(k1, k2), stride=(k1, k2))
        self.pos_emb = nn.Parameter(torch.zeros(max_patches, emb_dim))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)

    def forward(self, F):
        """Return (embeddings of shape (B, N, emb_dim), N) where N = Hp * Wp patches."""
        x = self.proj(F)
        B, D, Hp, Wp = x.shape
        N = Hp * Wp
        # Fail loudly: previously pos_emb[:N] silently truncated when N exceeded
        # max_patches, producing an opaque broadcast error on the add below.
        if N > self.pos_emb.size(0):
            raise ValueError(
                f"Input yields {N} patches but max_patches is {self.pos_emb.size(0)}"
            )
        x = x.flatten(2).transpose(1, 2)  # (B, D, Hp*Wp) -> (B, N, D)
        x = x + self.pos_emb[:N].unsqueeze(0)
        return x, N
78
+
79
def make_encoder(emb_dim=384, nhead=8, num_layers=3, dim_feedforward=1024, dropout=0.1):
    """Build a stack of standard (post-norm, ReLU) Transformer encoder layers.

    Returns an nn.TransformerEncoder expecting (S, B, emb_dim) input
    (default, non-batch-first layout).
    """
    layer = nn.TransformerEncoderLayer(
        d_model=emb_dim,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        activation='relu',
    )
    return nn.TransformerEncoder(layer, num_layers=num_layers)
84
+
85
class TransformerDecoderWrapper(nn.Module):
    """Autoregressive Transformer decoder over token ids, attending to encoder memory."""

    def __init__(self, vocab_size, emb_dim, nhead=8, num_layers=3, pad_idx=0, max_len=256):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        dec_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=nhead, dim_feedforward=emb_dim*4, dropout=0.1)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_layers)
        # Learned positional table for target positions up to max_len.
        self.pos_emb = nn.Parameter(torch.zeros(max_len, emb_dim))
        nn.init.trunc_normal_(self.pos_emb, std=0.1)
        self.out_proj = nn.Linear(emb_dim, vocab_size)
        self.pad_idx = pad_idx

    def generate_square_subsequent_mask(self, sz):
        """Additive causal mask: 0 on/below the diagonal, -inf strictly above."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def forward(self, tgt_tokens, memory, memory_key_padding_mask):
        """Decode teacher-forced targets against encoder memory.

        tgt_tokens: (B, T) ids; memory: (B, S, D);
        memory_key_padding_mask: (B, S) truthy-at-padding, or None.
        Returns logits of shape (B, T, vocab_size).
        """
        batch, seq_len = tgt_tokens.size()
        dev = tgt_tokens.device
        # Token + positional embeddings, then (B, T, D) -> (T, B, D) for nn.Transformer.
        embedded = self.tok_emb(tgt_tokens) + self.pos_emb[:seq_len, :].unsqueeze(0).expand(batch, -1, -1)
        tgt = embedded.transpose(0, 1)
        pad_mask = (tgt_tokens == self.pad_idx)
        if memory_key_padding_mask is not None:
            memory_key_padding_mask = memory_key_padding_mask.bool()
        causal = self.generate_square_subsequent_mask(seq_len).to(dev)
        dec_out = self.decoder(
            tgt,
            memory.transpose(0, 1),
            tgt_mask=causal,
            tgt_key_padding_mask=pad_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        # Back to batch-first before projecting to vocabulary logits.
        return self.out_proj(dec_out.transpose(0, 1))
114
+
115
+ # ==========================================
116
+ # 2. MAIN MODEL WRAPPER
117
+ # ==========================================
118
class KhmerOCR(PreTrainedModel):
    """Khmer OCR model: CNN backbone + Transformer encoder + BiLSTM + Transformer decoder.

    Images arrive pre-split into chunks; each chunk is encoded independently
    and the per-image chunk sequences are concatenated back into one memory
    sequence before autoregressive decoding.
    """
    config_class = KhmerOCRConfig

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.pad_idx = config.pad_idx
        self.emb_dim = config.emb_dim

        # Per-chunk pipeline: CNN features -> patch embeddings -> Transformer encoder.
        self.cnn = ImprovedFeatureExtractor()
        self.patch = PatchEncoder(512, emb_dim=self.emb_dim, k1=2, k2=1)
        self.enc = make_encoder(emb_dim=self.emb_dim, nhead=config.nhead, num_layers=config.num_encoder_layers)

        # Positional embeddings applied after chunks are merged into one sequence.
        self.global_pos = nn.Parameter(torch.zeros(config.max_global_len, self.emb_dim))
        nn.init.trunc_normal_(self.global_pos, std=0.02)

        # BiLSTM mixes context across chunk boundaries; hidden size is halved so
        # the concatenated bidirectional output keeps emb_dim features.
        self.context_bilstm = nn.LSTM(
            input_size=self.emb_dim,
            hidden_size=self.emb_dim // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )

        self.dec = TransformerDecoderWrapper(self.vocab_size, emb_dim=self.emb_dim, nhead=config.nhead,
                                             num_layers=config.num_decoder_layers, pad_idx=self.pad_idx)

    def forward(self, chunk_lists, tgt_tokens=None):
        """Encode chunked images and optionally decode target tokens.

        Args:
            chunk_lists: list (length B) of lists of image-chunk tensors; all
                chunks must share one shape since they are stacked together —
                presumably (1, H, W) given the CNN's single input channel
                (TODO confirm against the preprocessing code).
            tgt_tokens: (B, T) target ids for teacher forcing, or None.

        Returns:
            (B, T, vocab_size) logits when tgt_tokens is given; otherwise the
            encoded memory (B, S, emb_dim) for external search/decoding.
        """
        # 1. Flatten: batch every chunk of every image through the CNN at once.
        chunk_sizes = [len(c) for c in chunk_lists]
        flat_input_list = [chunk for img_chunks in chunk_lists for chunk in img_chunks]
        flat_input = torch.stack(flat_input_list)

        # 2. Pipeline
        f = self.cnn(flat_input)
        p, _ = self.patch(f)
        p = p.transpose(0, 1).contiguous()  # (B*, N, D) -> (N, B*, D) for the encoder
        enc_out = self.enc(p)
        enc_out = enc_out.transpose(0, 1)   # back to (B*, N, D)

        # 3. Merge: concatenate each image's chunk sequences into one long sequence.
        batch_encoded_list = []
        cursor = 0
        feature_dim = enc_out.size(-1)
        for size in chunk_sizes:
            img_chunks = enc_out[cursor : cursor + size]
            merged_seq = img_chunks.reshape(-1, feature_dim)  # (size * N, D)
            batch_encoded_list.append(merged_seq)
            cursor += size

        # 4. Pad to a common length and add global positional embeddings.
        memory = pad_sequence(batch_encoded_list, batch_first=True, padding_value=0.0)
        B, T, _ = memory.shape
        limit = min(T, self.global_pos.size(0))
        pos_emb = self.global_pos[:limit, :].unsqueeze(0)

        if T > self.global_pos.size(0):
            # Sequence exceeds the positional table: truncate the memory to fit.
            memory = memory[:, :limit, :] + pos_emb
            T = limit
        else:
            memory = memory + pos_emb

        # 5. BiLSTM over the merged sequence for cross-chunk context.
        self.context_bilstm.flatten_parameters()
        memory, _ = self.context_bilstm(memory)

        # If inference (no targets), return memory for search
        if tgt_tokens is None:
            return memory

        # 6. Decoder: build key-padding mask (True = padded position, ignored).
        memory_key_padding_mask = torch.ones((B, T), dtype=torch.bool, device=memory.device)
        for i, seq in enumerate(batch_encoded_list):
            # min(...) accounts for the truncation in step 4.
            valid_len = min(seq.shape[0], T)
            memory_key_padding_mask[i, :valid_len] = False

        logits = self.dec(tgt_tokens, memory, memory_key_padding_mask)
        return logits