# modeling_khmerocr.py
"""Khmer OCR model wrapped as a Hugging Face ``PreTrainedModel``.

Pipeline: each input image arrives as a list of equally sized chunks. Every
chunk goes through a CNN feature extractor, is cut into patch tokens and
encoded by a Transformer encoder. The per-chunk token sequences of an image
are concatenated, padded across the batch, given global positional
embeddings and contextualised by a BiLSTM. A Transformer decoder then
attends to that memory to produce per-token vocabulary logits.
"""
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel

from configuration_khmerocr import KhmerOCRConfig


# ==========================================
# 1. HELPER CLASSES (SequenceSE, CNN, etc.)
# ==========================================
def _conv_bn_relu(in_channels, out_channels):
    """Return a 3x3 / stride 1 / padding 1 Conv2d -> BatchNorm2d -> ReLU stack.

    Uses the same ``nn.Sequential`` layout (submodule indices 0, 1, 2) as an
    inline definition would, so parameter / state-dict keys are unchanged.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, 1, 1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(True),
    )


class SequenceSE(nn.Module):
    """Squeeze-and-excitation gate computed independently for each width column.

    The feature map is averaged over its height axis (dim 2), passed through a
    1x1 ``Conv1d`` bottleneck (``channels -> channels // reduction ->
    channels``) and a sigmoid. The resulting per-channel, per-column weights
    rescale the input.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck reduction factor of the gating network.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Guard against a zero-width bottleneck when channels < reduction.
        hidden = max(1, channels // reduction)
        self.fc = nn.Sequential(
            nn.Conv1d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Gate ``x`` of shape (B, C, H, W); returns a tensor of the same shape."""
        b, c, _, w = x.size()
        # Squeeze the height axis only, so each column gets its own gate.
        y = torch.mean(x, dim=2).view(b, c, w)
        y = self.fc(y)
        y = y.view(b, c, 1, w)  # broadcast the gate back over the height axis
        return x * y


class ImprovedFeatureExtractor(nn.Module):
    """VGG-style CNN mapping a 1-channel image batch to a (B, 512, 2, 32) map.

    Height is reduced by 2x2 pooling twice, then 2x1 pooling twice. Width is
    only halved twice, presumably to keep horizontal resolution for the text
    sequence. ``SequenceSE`` gates follow the 256- and 512-channel stages. A
    final adaptive average pool fixes the output grid at 2 x 32.
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept stable so default initialisation and
        # state-dict ordering do not change.
        self.conv1 = _conv_bn_relu(1, 64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = _conv_bn_relu(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = _conv_bn_relu(128, 256)
        self.conv4 = _conv_bn_relu(256, 256)
        self.se3 = SequenceSE(256)
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        self.conv5 = _conv_bn_relu(256, 512)
        self.conv6 = _conv_bn_relu(512, 512)
        self.se4 = SequenceSE(512)
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        self.conv7 = nn.Conv2d(512, 512, 3, 1, 1)
        self.bn7 = nn.BatchNorm2d(512)
        self.relu7 = nn.ReLU(True)
        self.se5 = SequenceSE(512)
        self.final_pool = nn.AdaptiveAvgPool2d((2, 32))

    def forward(self, x):
        """Map ``x`` (B, 1, H, W) to features of shape (B, 512, 2, 32)."""
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        x = self.conv4(self.conv3(x))
        x = self.se3(x)
        x = self.pool3(x)
        x = self.conv6(self.conv5(x))
        x = self.se4(x)
        x = self.pool4(x)
        x = self.relu7(self.bn7(self.conv7(x)))
        x = self.se5(x)
        x = self.final_pool(x)
        return x


class PatchEncoder(nn.Module):
    """Embed non-overlapping (k1, k2) patches of a feature map as tokens.

    Args:
        in_channels: Channels of the incoming feature map.
        emb_dim: Token embedding size.
        k1: Patch height (also the vertical stride).
        k2: Patch width (also the horizontal stride).
        max_patches: Size of the learned positional-embedding table.
    """

    def __init__(self, in_channels, emb_dim, k1=2, k2=1, max_patches=256):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=(k1, k2), stride=(k1, k2))
        self.pos_emb = nn.Parameter(torch.zeros(max_patches, emb_dim))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)

    def forward(self, F):
        """Project ``F`` (B, C, H, W) into a token sequence.

        Returns:
            ``(tokens, n)``: ``tokens`` has shape (B, n, emb_dim) with learned
            positional embeddings added, and ``n`` is the number of patches
            per sample (patch rows x patch columns).

        Raises:
            ValueError: if ``n`` exceeds the positional table size.
        """
        x = self.proj(F)
        _, _, patch_rows, patch_cols = x.shape
        n = patch_rows * patch_cols
        if n > self.pos_emb.size(0):
            raise ValueError(
                f"feature map yields {n} patches but PatchEncoder supports at most "
                f"{self.pos_emb.size(0)} (max_patches)"
            )
        # (B, D, Hp, Wp) -> (B, Hp*Wp, D), row-major over the patch grid.
        x = x.flatten(2).transpose(1, 2)
        x = x + self.pos_emb[:n].unsqueeze(0)
        return x, n


def make_encoder(emb_dim=384, nhead=8, num_layers=3, dim_feedforward=1024, dropout=0.1):
    """Build a seq-first (``batch_first=False``) Transformer encoder stack.

    Inputs to the returned module must be shaped (S, B, emb_dim).
    """
    enc_layer = nn.TransformerEncoderLayer(
        d_model=emb_dim,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        activation='relu',
    )
    return nn.TransformerEncoder(enc_layer, num_layers=num_layers)


class TransformerDecoderWrapper(nn.Module):
    """Token embedding + learned positions + seq-first Transformer decoder + vocab head.

    Args:
        vocab_size: Output vocabulary size.
        emb_dim: Model width.
        nhead: Attention heads per layer.
        num_layers: Number of decoder layers.
        pad_idx: Token id treated as padding (zero embedding, masked as key).
        max_len: Maximum supported target length (positional table size).
    """

    def __init__(self, vocab_size, emb_dim, nhead=8, num_layers=3, pad_idx=0, max_len=256):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        dec_layer = nn.TransformerDecoderLayer(
            d_model=emb_dim, nhead=nhead, dim_feedforward=emb_dim * 4, dropout=0.1
        )
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_layers)
        self.pos_emb = nn.Parameter(torch.zeros(max_len, emb_dim))
        # NOTE(review): std=0.1 differs from the 0.02 used by the other
        # positional tables in this file; left unchanged to keep from-scratch
        # initialisation identical.
        nn.init.trunc_normal_(self.pos_emb, std=0.1)
        self.out_proj = nn.Linear(emb_dim, vocab_size)
        self.pad_idx = pad_idx

    def generate_square_subsequent_mask(self, sz):
        """Return a float (sz, sz) causal mask: 0.0 on/below the diagonal, -inf above.

        Kept for backward compatibility. ``forward`` uses an equivalent
        boolean mask so its dtype matches the boolean key-padding masks.
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    @staticmethod
    def _causal_mask(sz, device):
        """Boolean (sz, sz) causal mask; True marks future positions to block."""
        return torch.triu(torch.ones(sz, sz, dtype=torch.bool, device=device), diagonal=1)

    def forward(self, tgt_tokens, memory, memory_key_padding_mask=None):
        """Decode ``tgt_tokens`` against ``memory`` with teacher forcing.

        Args:
            tgt_tokens: (B, T) token ids. Positions equal to ``pad_idx`` are
                ignored as attention keys.
            memory: (B, S, emb_dim) encoder output.
            memory_key_padding_mask: Optional (B, S) mask; truthy entries are
                ignored. Converted to bool.

        Returns:
            Logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds the positional table size (``max_len``).
        """
        _, tgt_len = tgt_tokens.size()
        if tgt_len > self.pos_emb.size(0):
            raise ValueError(
                f"target length {tgt_len} exceeds decoder max_len {self.pos_emb.size(0)}"
            )
        device = tgt_tokens.device

        tok = self.tok_emb(tgt_tokens)
        pos = self.pos_emb[:tgt_len].unsqueeze(0)  # broadcasts over the batch
        tgt = (tok + pos).transpose(0, 1)  # -> (T, B, D) for the seq-first decoder

        tgt_key_padding_mask = tgt_tokens == self.pad_idx
        if memory_key_padding_mask is not None:
            memory_key_padding_mask = memory_key_padding_mask.bool()
        # Boolean causal mask: same masking as the float -inf version, but its
        # dtype matches the boolean padding masks (mixing them is deprecated
        # in PyTorch and emits a warning on every call).
        tgt_mask = self._causal_mask(tgt_len, device)

        mem = memory.transpose(0, 1)  # -> (S, B, D)
        dec_out = self.decoder(
            tgt,
            mem,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.out_proj(dec_out.transpose(0, 1))


# ==========================================
# 2. MAIN MODEL WRAPPER
# ==========================================
class KhmerOCR(PreTrainedModel):
    """Chunked-image OCR model.

    Components:
        - A CNN and Transformer encoder applied to each chunk.
        - A BiLSTM over each image's merged token sequence.
        - A Transformer decoder producing vocabulary logits.

    Config fields read: ``vocab_size``, ``pad_idx``, ``emb_dim``, ``nhead``,
    ``num_encoder_layers``, ``num_decoder_layers``, ``max_global_len``.
    """

    config_class = KhmerOCRConfig

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.pad_idx = config.pad_idx
        self.emb_dim = config.emb_dim

        self.cnn = ImprovedFeatureExtractor()
        # 512-channel 2x32 CNN output with 2x1 patches -> 32 tokens per chunk.
        self.patch = PatchEncoder(512, emb_dim=self.emb_dim, k1=2, k2=1)
        self.enc = make_encoder(
            emb_dim=self.emb_dim, nhead=config.nhead, num_layers=config.num_encoder_layers
        )
        # Positional table over the merged (all-chunks) sequence of one image.
        self.global_pos = nn.Parameter(torch.zeros(config.max_global_len, self.emb_dim))
        nn.init.trunc_normal_(self.global_pos, std=0.02)
        self.context_bilstm = nn.LSTM(
            input_size=self.emb_dim,
            hidden_size=self.emb_dim // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.dec = TransformerDecoderWrapper(
            self.vocab_size,
            emb_dim=self.emb_dim,
            nhead=config.nhead,
            num_layers=config.num_decoder_layers,
            pad_idx=self.pad_idx,
        )

    def forward(self, chunk_lists, tgt_tokens=None):
        """Encode a batch of chunked images and optionally decode target tokens.

        Args:
            chunk_lists: One entry per image; each entry is a sequence of chunk
                tensors. All chunks in the batch are stacked together, so they
                must share one shape. They are presumably (1, H, W) grayscale
                crops, since the CNN's first conv expects one input channel.
            tgt_tokens: Optional (B, T_tgt) token ids for teacher-forced decoding.

        Returns:
            If ``tgt_tokens`` is None: the BiLSTM memory of shape
            (B, T, 2 * (emb_dim // 2)). T is the longest merged sequence,
            capped at ``max_global_len``.
            Otherwise: decoder logits of shape (B, T_tgt, vocab_size).

        Raises:
            ValueError: if ``chunk_lists`` is empty or any image has no chunks
                (an all-padding memory row would yield NaN attention).
        """
        # 1. Flatten all chunks of all images into one CNN batch.
        if len(chunk_lists) == 0:
            raise ValueError("chunk_lists must contain at least one image")
        chunk_sizes = [len(img_chunks) for img_chunks in chunk_lists]
        if any(size == 0 for size in chunk_sizes):
            raise ValueError("every image in chunk_lists must have at least one chunk")
        flat_input = torch.stack([chunk for img_chunks in chunk_lists for chunk in img_chunks])

        # 2. Per-chunk pipeline: CNN -> patch tokens -> seq-first encoder.
        features = self.cnn(flat_input)
        patches, _ = self.patch(features)
        enc_out = self.enc(patches.transpose(0, 1).contiguous()).transpose(0, 1)

        # 3. Merge: concatenate each image's chunk token sequences in order.
        feature_dim = enc_out.size(-1)
        batch_encoded_list = [
            img_chunks.reshape(-1, feature_dim)
            for img_chunks in torch.split(enc_out, chunk_sizes)
        ]
        lengths = [seq.size(0) for seq in batch_encoded_list]

        # 4. Pad to the longest image, cap at the global positional table size
        #    (longer sequences are truncated), and add global positions.
        memory = pad_sequence(batch_encoded_list, batch_first=True, padding_value=0.0)
        seq_len = min(memory.size(1), self.global_pos.size(0))
        memory = memory[:, :seq_len, :] + self.global_pos[:seq_len].unsqueeze(0)

        # 5. BiLSTM context.
        # NOTE(review): the LSTM runs over the padded batch without packing, so
        # for images shorter than seq_len the backward direction also consumes
        # the padded tail. An image's outputs can therefore depend on the other
        # images in the batch. Kept as-is so trained weights behave the same;
        # consider pack_padded_sequence when retraining.
        self.context_bilstm.flatten_parameters()
        memory, _ = self.context_bilstm(memory)

        # Inference (no targets): return memory for an external search.
        if tgt_tokens is None:
            return memory

        # 6. Decoder. Key-padding mask: True beyond each image's valid length.
        valid_lengths = torch.tensor(lengths, device=memory.device).clamp(max=seq_len)
        positions = torch.arange(seq_len, device=memory.device)
        memory_key_padding_mask = positions.unsqueeze(0) >= valid_lengths.unsqueeze(1)

        logits = self.dec(tgt_tokens, memory, memory_key_padding_mask)
        return logits