Image-to-Text
Transformers
Safetensors
Khmer
khmer-ocr
feature-extraction
transformer
text-recognition
crnn
khmer-text-recognition
custom_code
khmer-text-recognition / modeling_khmerocr.py
Darayut's picture
Upload modeling_khmerocr.py with huggingface_hub
f5bb5c5 verified
# modeling_khmerocr.py
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel
from configuration_khmerocr import KhmerOCRConfig
# ==========================================
# 1. HELPER CLASSES (SequenceSE, CNN, etc.)
# ==========================================
class SequenceSE(nn.Module):
    """Squeeze-and-excitation gate computed separately for each width position.

    The input ``x`` (unpacked as ``b, c, h, w``) is averaged over dim 2, giving a
    (b, c, w) descriptor. That descriptor goes through a 1x1 ``Conv1d``
    bottleneck (channels -> channels // reduction -> channels) with a final
    sigmoid. The resulting (b, c, w) gate is broadcast back over dim 2 and
    multiplied into ``x``.

    Args:
        channels: number of input/output channels ``c``.
        reduction: bottleneck divisor for the hidden width (integer division).
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        # Attribute name and Sequential layout (indices 0..3) keep state_dict keys stable.
        self.fc = nn.Sequential(
            nn.Conv1d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, chans, _, width = x.size()
        squeezed = x.mean(dim=2).view(batch, chans, width)
        gate = self.fc(squeezed)
        # Re-insert the averaged axis so the gate broadcasts over it.
        return x * gate.unsqueeze(2)
class ImprovedFeatureExtractor(nn.Module):
    """VGG-style CNN backbone for single-channel images.

    Stage layout:
    - Two 2x2 max-pools halve both height and width.
    - Two (2, 1) max-pools then halve only the height, keeping horizontal
      resolution.
    - ``SequenceSE`` gates follow the 4th, 6th and 7th convolutions.
    - A final ``AdaptiveAvgPool2d((2, 32))`` fixes the spatial output at 2 x 32,
      so the result has shape (B, 512, 2, 32) for inputs that survive the
      pooling.
    """

    def __init__(self):
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            # 3x3 same-padding conv -> BatchNorm -> in-place ReLU (Sequential indices 0/1/2).
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, 1, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            )

        # Attribute names and creation order are kept so state_dict keys
        # (and seeded initialisation) are unchanged.
        self.conv1 = conv_bn_relu(1, 64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = conv_bn_relu(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = conv_bn_relu(128, 256)
        self.conv4 = conv_bn_relu(256, 256)
        self.se3 = SequenceSE(256)
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        self.conv5 = conv_bn_relu(256, 512)
        self.conv6 = conv_bn_relu(512, 512)
        self.se4 = SequenceSE(512)
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        # Last conv stage uses separate attributes (conv7/bn7/relu7), not a Sequential.
        self.conv7 = nn.Conv2d(512, 512, 3, 1, 1)
        self.bn7 = nn.BatchNorm2d(512)
        self.relu7 = nn.ReLU(True)
        self.se5 = SequenceSE(512)
        self.final_pool = nn.AdaptiveAvgPool2d((2, 32))

    def forward(self, x):
        pipeline = (
            self.conv1, self.pool1,
            self.conv2, self.pool2,
            self.conv3, self.conv4, self.se3, self.pool3,
            self.conv5, self.conv6, self.se4, self.pool4,
            self.conv7, self.bn7, self.relu7, self.se5,
            self.final_pool,
        )
        for stage in pipeline:
            x = stage(x)
        return x
class PatchEncoder(nn.Module):
    """Embed non-overlapping patches of a feature map and add learned positions.

    A ``Conv2d`` whose kernel equals its stride ``(k1, k2)`` projects each
    patch to ``emb_dim`` channels. The resulting (Hp, Wp) grid is flattened
    row-major into a token sequence. The first Hp*Wp rows of a learned
    (max_patches, emb_dim) table are added as positional embeddings.

    Args:
        in_channels: channels of the incoming feature map.
        emb_dim: token embedding width.
        k1, k2: patch height and width (also the conv stride).
        max_patches: rows of the positional table. A grid with more patches
            than this fails when the positions are added.
    """

    def __init__(self, in_channels, emb_dim, k1=2, k2=1, max_patches=256):
        super().__init__()
        patch = (k1, k2)
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch, stride=patch)
        self.pos_emb = nn.Parameter(torch.zeros(max_patches, emb_dim))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)

    def forward(self, F):
        """Return ``(tokens, num_patches)``; tokens has shape (B, Hp*Wp, emb_dim)."""
        grid = self.proj(F)
        _, _, rows, cols = grid.shape
        num_patches = rows * cols
        tokens = grid.flatten(2).permute(0, 2, 1)
        return tokens + self.pos_emb[:num_patches][None, :, :], num_patches
def make_encoder(emb_dim=384, nhead=8, num_layers=3, dim_feedforward=1024, dropout=0.1):
    """Build a stack of ``num_layers`` identical ReLU Transformer encoder layers.

    The layer uses PyTorch's default ``batch_first=False``, so the returned
    ``nn.TransformerEncoder`` expects sequence-first input of shape
    (S, B, emb_dim).

    Args:
        emb_dim: model width (``d_model``).
        nhead: attention heads per layer.
        num_layers: number of stacked layers (the template layer is cloned).
        dim_feedforward: hidden width of each feed-forward sub-block.
        dropout: dropout probability inside each layer.
    """
    layer_template = nn.TransformerEncoderLayer(
        d_model=emb_dim,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        activation='relu',
    )
    stack = nn.TransformerEncoder(layer_template, num_layers=num_layers)
    return stack
class TransformerDecoderWrapper(nn.Module):
    """Autoregressive Transformer decoder from token ids to vocabulary logits.

    Inputs and outputs are batch-first. Internally the sequence-first
    ``nn.TransformerDecoder`` (default ``batch_first=False``) is used, so
    tensors are transposed around the call.

    Args:
        vocab_size: number of token ids / output logits.
        emb_dim: model width (``d_model``).
        nhead: attention heads per layer.
        num_layers: number of stacked decoder layers.
        pad_idx: token id used as padding. Its embedding is the Embedding
            ``padding_idx`` and those target positions are masked as keys.
        max_len: rows of the learned positional table, i.e. the longest target
            sequence accepted by ``forward``.
    """

    def __init__(self, vocab_size, emb_dim, nhead=8, num_layers=3, pad_idx=0, max_len=256):
        super().__init__()
        # Creation order and attribute names are unchanged (state_dict compatibility).
        self.tok_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        dec_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=nhead, dim_feedforward=emb_dim * 4, dropout=0.1)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_layers)
        self.pos_emb = nn.Parameter(torch.zeros(max_len, emb_dim))
        nn.init.trunc_normal_(self.pos_emb, std=0.1)
        self.out_proj = nn.Linear(emb_dim, vocab_size)
        self.pad_idx = pad_idx

    def generate_square_subsequent_mask(self, sz):
        """Return an additive float causal mask of shape (sz, sz).

        Entries on or below the diagonal are 0.0 and entries above it are -inf.
        ``forward`` no longer uses this. It is kept unchanged for external
        callers.
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, tgt_tokens, memory, memory_key_padding_mask):
        """Decode ``tgt_tokens`` against ``memory`` and return per-position logits.

        Args:
            tgt_tokens: (B, T) token ids. Positions equal to ``pad_idx`` are
                masked as attention keys.
            memory: (B, S, emb_dim) encoder output.
            memory_key_padding_mask: optional (B, S) mask, truthy = ignore.
                It is converted to bool.

        Returns:
            (B, T, vocab_size) logits.

        Raises:
            ValueError: if T exceeds the positional table length ``max_len``.
        """
        _batch, T = tgt_tokens.size()
        max_len = self.pos_emb.size(0)
        if T > max_len:
            raise ValueError(
                f"target length {T} exceeds the decoder's positional table (max_len={max_len})"
            )
        device = tgt_tokens.device
        tok = self.tok_emb(tgt_tokens)
        pos = self.pos_emb[:T, :].unsqueeze(0)  # broadcasts over the batch
        tgt = (tok + pos).transpose(0, 1)
        tgt_key_padding_mask = tgt_tokens == self.pad_idx
        if memory_key_padding_mask is not None:
            memory_key_padding_mask = memory_key_padding_mask.bool()
        # Bool causal mask (True = blocked), built directly on the target device.
        # Using the same dtype as tgt_key_padding_mask avoids PyTorch's deprecated
        # "mismatched key_padding_mask and attn_mask" path. Attention is
        # numerically identical to the old float 0/-inf mask.
        tgt_mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
        mem = memory.transpose(0, 1)
        dec_out = self.decoder(
            tgt,
            mem,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.out_proj(dec_out.transpose(0, 1))
# ==========================================
# 2. MAIN MODEL WRAPPER
# ==========================================
class KhmerOCR(PreTrainedModel):
    """Chunked-image OCR model: CNN -> patch Transformer encoder -> BiLSTM -> Transformer decoder.

    Each image arrives as a list of chunk tensors. All chunks in the batch are
    stacked and encoded independently. Each image's chunk token sequences are
    then concatenated, padded across the batch, offset by a learned global
    position table, and passed through a single-layer bidirectional LSTM.
    """

    config_class = KhmerOCRConfig

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.pad_idx = config.pad_idx
        self.emb_dim = config.emb_dim
        width = self.emb_dim

        # Submodule names and creation order are unchanged (checkpoint keys).
        self.cnn = ImprovedFeatureExtractor()
        # 512 = channel count produced by ImprovedFeatureExtractor.
        self.patch = PatchEncoder(512, emb_dim=width, k1=2, k2=1)
        self.enc = make_encoder(emb_dim=width, nhead=config.nhead, num_layers=config.num_encoder_layers)
        self.global_pos = nn.Parameter(torch.zeros(config.max_global_len, width))
        nn.init.trunc_normal_(self.global_pos, std=0.02)
        # Two directions of width // 2 concatenate back to ``width`` features.
        self.context_bilstm = nn.LSTM(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.dec = TransformerDecoderWrapper(
            self.vocab_size,
            emb_dim=width,
            nhead=config.nhead,
            num_layers=config.num_decoder_layers,
            pad_idx=self.pad_idx,
        )

    def forward(self, chunk_lists, tgt_tokens=None):
        """Encode a batch of chunked images and optionally decode target tokens.

        Args:
            chunk_lists: one sequence of chunk tensors per image. All chunks in
                the batch are stacked together, so they must share a shape. The
                single-channel CNN implies each chunk is (1, H, W).
            tgt_tokens: optional (B, T_tgt) token ids for the decoder.

        Returns:
            If ``tgt_tokens`` is None, the BiLSTM memory (B, T, emb_dim) is
            returned without a padding mask. Otherwise the decoder logits
            (B, T_tgt, vocab_size) are returned.
        """
        chunk_sizes = [len(chunks) for chunks in chunk_lists]
        stacked = torch.stack([chunk for chunks in chunk_lists for chunk in chunks])

        # Per-chunk encoding. The Transformer encoder is sequence-first, hence the transposes.
        patches, _ = self.patch(self.cnn(stacked))
        encoded = self.enc(patches.transpose(0, 1).contiguous()).transpose(0, 1)

        # Regroup consecutive chunks by image and concatenate their token sequences.
        feature_dim = encoded.size(-1)
        per_image = [
            group.reshape(-1, feature_dim)
            for group in torch.split(encoded, chunk_sizes, dim=0)
        ]

        memory = pad_sequence(per_image, batch_first=True, padding_value=0.0)
        T = memory.size(1)

        # Learned global positions. Sequences longer than the table are truncated.
        max_positions = self.global_pos.size(0)
        if T > max_positions:
            memory = memory[:, :max_positions, :]
            T = max_positions
        memory = memory + self.global_pos[:T, :].unsqueeze(0)

        self.context_bilstm.flatten_parameters()
        memory, _ = self.context_bilstm(memory)

        if tgt_tokens is None:
            return memory

        # True marks memory positions beyond each image's (possibly truncated) length.
        valid_lengths = torch.tensor(
            [seq.shape[0] for seq in per_image], device=memory.device
        ).clamp(max=T)
        positions = torch.arange(T, device=memory.device)
        memory_key_padding_mask = positions.unsqueeze(0) >= valid_lengths.unsqueeze(1)
        return self.dec(tgt_tokens, memory, memory_key_padding_mask)