import math

import torch
import torch.nn as nn

try:
    import torchvision.models as models
except ImportError:  # torchvision is only required by OCRBackbone
    models = None


class OCRBackbone(nn.Module):
    """ResNet-34 encoder that turns an image into a left-to-right feature sequence.

    The final feature map's height is average-pooled to 1 and every remaining
    width position becomes one sequence step.

    Args:
        output_dim: Feature size ``D`` of every sequence step.
        pretrained: If True, load torchvision's ``IMAGENET1K_V1`` ResNet-34 weights.
        in_channels: Number of image channels. Any value other than 3 (e.g. 1 for
            grayscale) rebuilds the first convolution and initialises it with the
            RGB kernel averaged over its input channels.

    Raises:
        ImportError: If torchvision is not installed.
    """

    def __init__(self, output_dim=512, pretrained=True, in_channels=3):
        super().__init__()
        if models is None:
            raise ImportError("OCRBackbone requires torchvision to be installed")

        base_model = models.resnet34(weights="IMAGENET1K_V1" if pretrained else None)

        if in_channels != 3:
            base_conv = base_model.conv1
            self.conv1 = nn.Conv2d(
                in_channels,
                base_conv.out_channels,
                kernel_size=base_conv.kernel_size,
                stride=base_conv.stride,
                padding=base_conv.padding,
                bias=False,
            )
            with torch.no_grad():
                # Channel-mean of the RGB kernel, copied into every input channel
                # (for in_channels == 1 this is exactly the mean kernel).
                mean_kernel = base_conv.weight.mean(dim=1, keepdim=True)
                self.conv1.weight.copy_(mean_kernel.repeat(1, in_channels, 1, 1))
        else:
            self.conv1 = base_model.conv1

        self.bn1 = base_model.bn1
        self.relu = base_model.relu
        self.maxpool = base_model.maxpool
        self.layer1 = base_model.layer1  # ResNet stages reused unchanged
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4

        self.pool = nn.AdaptiveAvgPool2d((1, None))  # collapse H to 1, keep W'
        self.output_conv = nn.Conv2d(512, output_dim, kernel_size=1)
        self.norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        """
        Args:
            x: [B, C, H, W] image batch.

        Returns:
            features: [B, W', output_dim] LayerNorm-ed sequence features.
        """
        x = self.conv1(x)  # [B, 64, H/2, W/2]
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)  # [B, 64, H/4, W/4]

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.pool(x)  # [B, 512, 1, W']
        x = self.output_conv(x)  # [B, output_dim, 1, W']
        x = x.squeeze(2).permute(0, 2, 1)  # [B, W', output_dim]
        return self.norm(x)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a sequence-first tensor.

    Args:
        d_model: Feature size of the input (odd sizes are supported).
        max_len: Longest sequence the precomputed table can encode.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than div_term entries.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to ``x`` of shape (seq_len, batch_size, d_model).

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional-encoding "
                f"capacity max_len={self.pe.size(0)}"
            )
        return x + self.pe[:seq_len]


class TransformerDecoder(nn.Module):
    """Token embedding + sinusoidal positions + nn.TransformerDecoder + vocab head.

    Args:
        vocab_size: Number of output/input token ids.
        embed_dim: Model width ``D``.
        num_heads: Attention heads per layer.
        num_layers: Number of stacked decoder layers.
        dropout: Dropout used inside each decoder layer.
        max_len: Longest target sequence the positional encoding supports.
    """

    def __init__(self, vocab_size, embed_dim=512, num_heads=8, num_layers=3,
                 dropout=0.1, max_len=100):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, max_len=max_len)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, vocab_size)
        self.embed_scale = embed_dim ** 0.5

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return an (sz, sz) float mask with -inf above the diagonal (no peeking ahead)."""
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def forward(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None):
        """
        Args:
            tgt: (B, T) target token ids.
            memory: (B, S, D) encoder features.
            tgt_mask: (T, T) attention mask, or None to use a causal mask.
            tgt_padding_mask: (B, T) with True at padded positions, or None.

        Returns:
            (B, T, vocab_size) logits.
        """
        tgt_emb = self.embedding(tgt) * self.embed_scale
        tgt_emb = self.pos_encoding(tgt_emb.transpose(0, 1))  # (T, B, D)

        if tgt_mask is None:
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(1), device=tgt.device)

        out = self.decoder(
            tgt=tgt_emb,
            memory=memory.transpose(0, 1),  # (S, B, D)
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        return self.fc(out.transpose(0, 1))  # (B, T, vocab_size)


class OCRModel(nn.Module):
    """Image-to-token OCR model: image encoder -> optional projection -> Transformer decoder.

    Args:
        vocab_size: Number of token ids.
        encoder: Module mapping images to [B, S, encoder_dim]; defaults to OCRBackbone.
        encoder_dim: Feature size produced by the encoder.
        embed_dim: Decoder width; a Linear projection is inserted if it differs
            from ``encoder_dim``.
        num_heads, num_layers, dropout: Decoder hyper-parameters.
        sos_token_id, eos_token_id: Start / end token ids used by decoding.
        max_len: Default decoding length and positional-encoding capacity.
    """

    def __init__(
        self,
        vocab_size,
        encoder=None,
        encoder_dim=512,
        embed_dim=512,
        num_heads=8,
        num_layers=3,
        dropout=0.1,
        sos_token_id=1,
        eos_token_id=2,
        max_len=100,
    ):
        super().__init__()
        self.encoder = encoder if encoder is not None else OCRBackbone(output_dim=encoder_dim)
        self.encoder_proj = (
            nn.Linear(encoder_dim, embed_dim) if encoder_dim != embed_dim else nn.Identity()
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            dropout=dropout,
            max_len=max_len,
        )
        # NOTE(review): forward() never uses this head; self.decoder.fc already
        # produces vocab logits. Removing it would drop the output_layer.* keys
        # from state_dict, so confirm no checkpoint relies on them first.
        self.output_layer = nn.Linear(embed_dim, vocab_size)

        self.sos_token_id = sos_token_id
        self.eos_token_id = eos_token_id
        self.max_len = max_len

    def generate_square_subsequent_mask(self, sz: int, device=None) -> torch.Tensor:
        """Return an (sz, sz) float causal mask: -inf strictly above the diagonal, else 0."""
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def forward(
        self,
        images: torch.Tensor,
        tgt_input: torch.Tensor,
        attention_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Teacher-forced decoding.

        Args:
            images: [B, C, H, W] input for the encoder.
            tgt_input: [B, T] decoder input token ids (presumably starting with SOS).
            attention_mask: [B, T], True / non-zero for real tokens and False / 0 for
                PAD. None disables padding masking.

        Returns:
            logits: [B, T, vocab_size].
        """
        memory = self.encoder_proj(self.encoder(images))  # [B, S, embed_dim]

        tgt_mask = self.generate_square_subsequent_mask(
            tgt_input.size(1), device=images.device
        )  # [T, T]

        # nn.TransformerDecoder wants True at positions to ignore. Coerce to bool
        # first: `~` on an integer 0/1 mask is a bitwise NOT, not a logical one.
        tgt_key_padding_mask = None if attention_mask is None else ~attention_mask.bool()

        # The decoder's own fc head already returns vocab logits.
        return self.decoder(
            tgt=tgt_input,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_padding_mask=tgt_key_padding_mask,
        )

    def _strip_special_tokens(self, seq):
        """Drop the leading SOS token and everything from the first EOS onward."""
        if self.eos_token_id in seq:
            return seq[1 : seq.index(self.eos_token_id)]
        return seq[1:]

    def _pad_predictions(self, seqs, device):
        """Stack token-id lists into a [len(seqs), longest] long tensor padded with EOS."""
        longest = max((len(s) for s in seqs), default=0)
        out = torch.full((len(seqs), longest), self.eos_token_id, dtype=torch.long, device=device)
        for i, seq in enumerate(seqs):
            if seq:
                out[i, : len(seq)] = torch.tensor(seq, dtype=torch.long, device=device)
        return out

    def _beam_search_single(self, mem, max_len, beam_size, length_penalty, device):
        """Beam-search one encoded sample ``mem`` ([1, S, D]); return stripped token ids."""
        eos = self.eos_token_id

        def normalized_score(candidate):
            seq, score = candidate
            return score / (len(seq) ** length_penalty) if length_penalty > 0 else score

        beams = [([self.sos_token_id], 0.0)]
        for _ in range(max_len):
            active = [i for i, (seq, _) in enumerate(beams) if seq[-1] != eos]
            if not active:
                break  # every beam has already emitted EOS

            # Each step appends exactly one token to every unfinished beam, so they
            # all share one length and can be decoded in a single batched call.
            tgt = torch.tensor([beams[i][0] for i in active], device=device)  # [K, T]
            logits = self.decoder(tgt, mem.expand(len(active), -1, -1))  # [K, T, V]
            logp = torch.log_softmax(logits[:, -1, :], dim=-1)  # [K, V]
            topk_logp, topk_ids = logp.topk(min(beam_size, logp.size(-1)), dim=-1)
            expansions = dict(zip(active, zip(topk_ids.tolist(), topk_logp.tolist())))

            # Rebuild candidates in the original beam order (finished beams are
            # carried over unchanged) so the stable sort breaks ties the same way.
            candidates = []
            for i, (seq, score) in enumerate(beams):
                if i not in expansions:
                    candidates.append((seq, score))
                    continue
                ids, logps = expansions[i]
                candidates.extend((seq + [tok], score + lp) for tok, lp in zip(ids, logps))

            beams = sorted(candidates, key=normalized_score, reverse=True)[:beam_size]

        return self._strip_special_tokens(beams[0][0])

    @torch.no_grad()
    def predict(self, images, max_len=None, beam_size=5, length_penalty=0.7):
        """Beam-search decoding, one sample at a time.

        Puts the model in eval mode (and leaves it there).

        Args:
            images: Tensor [B, C, H, W].
            max_len: Maximum number of decoding steps (defaults to self.max_len).
            beam_size: Number of beams kept per step (must be >= 1).
            length_penalty: Exponent for length normalisation of beam scores;
                values <= 0 disable normalisation.

        Returns:
            Long tensor [B, L_pred] of token ids without the leading SOS, truncated
            before the first EOS, right-padded with eos_token_id.

        Raises:
            ValueError: If beam_size < 1, or a hypothesis outgrows the decoder's
                positional-encoding capacity.
        """
        if beam_size < 1:
            raise ValueError(f"beam_size must be >= 1, got {beam_size}")
        self.eval()
        max_len = max_len or self.max_len
        device = images.device

        memory = self.encoder_proj(self.encoder(images))  # [B, S, embed_dim]
        results = [
            self._beam_search_single(memory[b : b + 1], max_len, beam_size, length_penalty, device)
            for b in range(memory.size(0))
        ]
        return self._pad_predictions(results, device)

    @torch.no_grad()
    def predict_greedy(self, images, max_len=None):
        """Greedy (argmax) decoding of the whole batch at once.

        Puts the model in eval mode (and leaves it there).

        Args:
            images: Tensor [B, C, H, W].
            max_len: Maximum number of decoding steps (defaults to self.max_len).

        Returns:
            Long tensor [B, L_pred] with the same layout as ``predict``.
        """
        self.eval()
        max_len = max_len or self.max_len
        device = images.device

        memory = self.encoder_proj(self.encoder(images))  # [B, S, embed_dim]
        batch_size = memory.size(0)
        seqs = torch.full((batch_size, 1), self.sos_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_len):
            if finished.all():  # also true for an empty batch
                break
            logits = self.decoder(seqs, memory)  # [B, T, vocab_size]
            next_tokens = logits[:, -1, :].argmax(dim=-1)
            seqs = torch.cat([seqs, next_tokens.unsqueeze(1)], dim=1)
            # Finished rows keep generating, but everything after their first
            # EOS is discarded by _strip_special_tokens.
            finished |= next_tokens == self.eos_token_id

        results = [self._strip_special_tokens(seq) for seq in seqs.tolist()]
        return self._pad_predictions(results, device)


class EarlyStopping:
    """Signal a stop once the validation loss fails to improve for ``patience`` calls.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Minimum decrease below the best loss that counts as improvement.
    """

    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')

    def __call__(self, val_loss):
        """Record ``val_loss``; return True to stop training, False to continue."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True  # stop training
        return False  # keep training