# Receipt_OCR / model_def.py
# Uploaded by RickyGM15 using huggingface_hub (commit e141a7d, verified).
import torch
import torch.nn as nn
import torchvision.models as models
import math
class OCRBackbone(nn.Module):
    """ResNet-34 feature extractor that turns an image into a feature sequence.

    The torchvision ResNet-34 stem and its four residual stages are reused.
    The feature map is then averaged over its height, projected to
    ``output_dim`` channels with a 1x1 convolution and layer-normalised, so
    every remaining horizontal position becomes one sequence step.

    Args:
        output_dim: Size ``D`` of each output feature vector.
        pretrained: When true, load torchvision's ``IMAGENET1K_V1`` weights.
        in_channels: Number of channels of the input images. ``3`` reuses the
            ResNet stem unchanged. Any other positive value builds a new stem
            whose kernels are the channel-wise mean of the ResNet stem
            kernels, repeated for every input channel (for ``1`` this is the
            usual grayscale adaptation).

    Raises:
        ValueError: If ``in_channels`` is smaller than 1.
    """

    def __init__(self, output_dim=512, pretrained=True, in_channels=3):
        super().__init__()
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")

        base_model = models.resnet34(weights="IMAGENET1K_V1" if pretrained else None)

        if in_channels == 3:
            self.conv1 = base_model.conv1
        else:
            # Adapt the stem to a different channel count (e.g. grayscale)
            # while keeping as much of the original filter shape as possible.
            self.conv1 = nn.Conv2d(
                in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
            )
            with torch.no_grad():
                mean_kernel = base_model.conv1.weight.mean(dim=1, keepdim=True)
                self.conv1.weight.copy_(mean_kernel.expand_as(self.conv1.weight))

        self.bn1 = base_model.bn1
        self.relu = base_model.relu
        self.maxpool = base_model.maxpool

        # Residual stages are taken over unchanged.
        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4

        # Collapse the height to 1 but keep the width W' as the time axis.
        self.pool = nn.AdaptiveAvgPool2d((1, None))
        self.output_conv = nn.Conv2d(512, output_dim, kernel_size=1)
        self.norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Images of shape ``[B, C, H, W]``.

        Returns:
            Layer-normalised features of shape ``[B, T, D]`` where ``T`` is
            the feature-map width ``W'`` and ``D`` is ``output_dim``.
        """
        x = self.conv1(x)    # [B, 64, H/2, W/2]
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)  # [B, 64, H/4, W/4]

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.pool(x)         # [B, 512, 1, W']
        x = self.output_conv(x)  # [B, output_dim, 1, W']
        x = x.squeeze(2).permute(0, 2, 1)  # [B, W', output_dim]
        return self.norm(x)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    The table is precomputed once for ``max_len`` positions and stored as the
    buffer ``pe`` with shape ``(max_len, 1, d_model)``. Even feature indices
    hold sines and odd indices hold cosines.

    Args:
        d_model: Feature size of the tensors that will be encoded. Both even
            and odd sizes are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``seq_len`` positions.

        Args:
            x: Tensor of shape ``(seq_len, batch_size, d_model)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional table size "
                f"{self.pe.size(0)}"
            )
        return x + self.pe[:seq_len]
class TransformerDecoder(nn.Module):
    """Token decoder: embedding, positional encoding, Transformer layers, vocab head.

    Inputs and outputs are batch-first; the tensors are transposed to the
    sequence-first layout only around the inner ``nn.TransformerDecoder``.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        embed_dim: Model width ``D``.
        num_heads: Attention heads per layer.
        num_layers: Number of stacked decoder layers.
        dropout: Dropout probability inside each decoder layer.
        max_len: Number of positions available in the positional encoding.
    """

    def __init__(self, vocab_size, embed_dim=512, num_heads=8, num_layers=3, dropout=0.1, max_len=100):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, max_len=max_len)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim * 4,
                dropout=dropout,
            ),
            num_layers=num_layers,
        )
        self.fc = nn.Linear(embed_dim, vocab_size)
        # Embeddings are multiplied by sqrt(D) before positions are added.
        self.embed_scale = embed_dim ** 0.5

    def generate_square_subsequent_mask(self, sz):
        """Return a ``(sz, sz)`` causal mask: ``-inf`` above the diagonal, 0 elsewhere.

        The mask keeps each position from attending to later positions.
        """
        blocked = torch.full((sz, sz), float('-inf'))
        return blocked.triu(diagonal=1)

    def forward(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None):
        """Decode target tokens against encoder memory.

        Args:
            tgt: Token ids of shape ``(B, T)``.
            memory: Encoder features of shape ``(B, S, D)``.
            tgt_mask: Optional ``(T, T)`` attention mask; a causal mask is
                built when it is omitted.
            tgt_padding_mask: Optional ``(B, T)`` key-padding mask, passed on
                as ``tgt_key_padding_mask``.

        Returns:
            Logits of shape ``(B, T, vocab_size)``.
        """
        # Fall back to a causal mask when the caller does not provide one.
        if tgt_mask is None:
            steps = tgt.size(1)
            tgt_mask = self.generate_square_subsequent_mask(steps).to(tgt.device)

        # Scaled embeddings, moved to (T, B, D) before positions are added.
        scaled = self.embedding(tgt) * self.embed_scale
        seq_first = self.pos_encoding(scaled.transpose(0, 1))

        hidden = self.decoder(
            tgt=seq_first,
            memory=memory.transpose(0, 1),  # (S, B, D)
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )

        # Back to batch-first (B, T, D), then project onto the vocabulary.
        return self.fc(hidden.transpose(0, 1))
class OCRModel(nn.Module):
    """Image-to-text OCR model: an image encoder followed by a Transformer decoder.

    Args:
        vocab_size: Number of token ids the decoder can emit.
        encoder: Module mapping images ``[B, C, H, W]`` to features
            ``[B, S, encoder_dim]``. An ``OCRBackbone`` is created when omitted.
        encoder_dim: Feature size produced by the encoder.
        embed_dim: Decoder width. A linear projection is inserted when it
            differs from ``encoder_dim``.
        num_heads: Attention heads per decoder layer.
        num_layers: Number of decoder layers.
        dropout: Dropout probability in the decoder.
        sos_token_id: Id of the start-of-sequence token.
        eos_token_id: Id of the end-of-sequence token (also used as padding
            in the outputs of ``predict`` / ``predict_greedy``).
        max_len: Maximum decoding length; also the size of the decoder's
            positional table.
    """

    def __init__(
        self,
        vocab_size,
        encoder=None,
        encoder_dim=512,
        embed_dim=512,
        num_heads=8,
        num_layers=3,
        dropout=0.1,
        sos_token_id=1,
        eos_token_id=2,
        max_len=100
    ):
        super().__init__()
        self.encoder = encoder if encoder is not None else OCRBackbone(output_dim=encoder_dim)
        # Only project when encoder and decoder widths differ.
        self.encoder_proj = (
            nn.Linear(encoder_dim, embed_dim) if encoder_dim != embed_dim else nn.Identity()
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            dropout=dropout,
            max_len=max_len
        )
        self.sos_token_id = sos_token_id
        self.eos_token_id = eos_token_id
        self.max_len = max_len

    def generate_square_subsequent_mask(self, sz: int, device=None) -> torch.Tensor:
        """Return a ``(sz, sz)`` causal mask: ``-inf`` above the diagonal, 0 elsewhere.

        Args:
            sz: Sequence length.
            device: Optional device to create the mask on (default: CPU).
        """
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def forward(self,
                images: torch.Tensor,
                tgt_input: torch.Tensor,
                attention_mask: torch.Tensor = None
                ) -> torch.Tensor:
        """Teacher-forced forward pass.

        Args:
            images: ``[B, C, H, W]``.
            tgt_input: ``[B, T]`` token ids (including <SOS>).
            attention_mask: ``[B, T]``, truthy for real tokens and falsy for
                PAD. Boolean and 0/1 integer masks are both accepted; when
                omitted, no position is treated as padding.

        Returns:
            Logits of shape ``[B, T, vocab_size]`` (the decoder already
            applies the vocabulary projection).
        """
        memory = self._encode(images)  # [B, S, embed_dim]

        # Causal mask so a position cannot look at later positions.
        tgt_mask = self.generate_square_subsequent_mask(
            tgt_input.size(1), device=images.device
        )  # [T, T]

        # The decoder expects True at PAD positions, i.e. the inverse of
        # attention_mask. Coerce to bool first: `~` on an integer mask would
        # be a bitwise NOT, not a logical one.
        tgt_key_padding_mask = None
        if attention_mask is not None:
            tgt_key_padding_mask = ~attention_mask.bool()

        return self.decoder(
            tgt=tgt_input,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_padding_mask=tgt_key_padding_mask
        )

    def _encode(self, images):
        """Run the encoder and the optional projection: ``[B, S, embed_dim]``."""
        return self.encoder_proj(self.encoder(images))

    def _resolve_max_len(self, max_len):
        """Pick the decoding length, never exceeding the positional table size."""
        return min(max_len or self.max_len, self.max_len)

    def _strip_special(self, seq):
        """Drop the leading <SOS> and everything from the first <EOS> onward."""
        if self.eos_token_id in seq:
            return seq[1:seq.index(self.eos_token_id)]
        return seq[1:]

    def _pad_sequences(self, sequences, device):
        """Stack variable-length id lists into ``[B, L_max]``, padded with <EOS>."""
        width = max((len(seq) for seq in sequences), default=0)
        out = torch.full(
            (len(sequences), width), self.eos_token_id, dtype=torch.long, device=device
        )
        for row, seq in enumerate(sequences):
            if seq:
                out[row, :len(seq)] = torch.tensor(seq, dtype=torch.long, device=device)
        return out

    @torch.no_grad()
    def predict(self, images, max_len=None, beam_size=5, length_penalty=0.7):
        """Decode with beam search, one sample at a time.

        Note: this switches the module to eval mode and leaves it there.

        Args:
            images: Tensor ``[B, C, H, W]``.
            max_len: Maximum number of generated tokens (default
                ``self.max_len``; larger values are clamped to it because the
                decoder has no positions beyond that).
            beam_size: Number of beams kept per step; at most ``vocab_size``
                continuations are expanded per beam.
            length_penalty: Beams are ranked by
                ``score / len(seq) ** length_penalty``; values ``<= 0``
                disable the normalisation.

        Returns:
            Tensor ``[B, L_pred]`` of token ids without <SOS>, padded with <EOS>.

        Raises:
            ValueError: If ``beam_size`` is smaller than 1.
        """
        if beam_size < 1:
            raise ValueError(f"beam_size must be >= 1, got {beam_size}")

        self.eval()
        max_len = self._resolve_max_len(max_len)
        device = images.device
        memory = self._encode(images)  # [B, S, embed_dim]

        def normalized(entry):
            seq, score = entry
            return score / ((len(seq) ** length_penalty) if length_penalty > 0 else 1.0)

        final_outputs = []
        for b in range(memory.size(0)):
            mem = memory[b:b + 1]  # [1, S, D]
            beams = [([self.sos_token_id], 0.0)]

            for _ in range(max_len):
                candidates = []
                for seq, score in beams:
                    # Finished beams are carried over unchanged.
                    if seq[-1] == self.eos_token_id:
                        candidates.append((seq, score))
                        continue

                    tgt = torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(0)  # [1, T]
                    # The decoder builds its own causal mask when none is passed.
                    logits = self.decoder(tgt, mem)  # [1, T, vocab_size]
                    logp = torch.log_softmax(logits[:, -1, :], dim=-1)  # [1, vocab_size]

                    # topk cannot return more entries than the vocabulary has.
                    k = min(beam_size, logp.size(-1))
                    topk_logp, topk_ids = logp.topk(k, dim=-1)
                    for token, token_logp in zip(topk_ids[0].tolist(), topk_logp[0].tolist()):
                        candidates.append((seq + [token], score + token_logp))

                # Length-normalise and keep the beam_size best candidates.
                beams = sorted(candidates, key=normalized, reverse=True)[:beam_size]

                # Stop early once every surviving beam has emitted <EOS>.
                if all(seq[-1] == self.eos_token_id for seq, _ in beams):
                    break

            final_outputs.append(self._strip_special(beams[0][0]))

        return self._pad_sequences(final_outputs, device)

    @torch.no_grad()
    def predict_greedy(self, images, max_len=None):
        """Decode greedily (argmax at every step), one sample at a time.

        Note: this switches the module to eval mode and leaves it there.

        Args:
            images: Tensor ``[B, C, H, W]``.
            max_len: Maximum number of generated tokens (default
                ``self.max_len``; larger values are clamped to it).

        Returns:
            Tensor ``[B, L_pred]`` of token ids without <SOS>, padded with <EOS>.
        """
        self.eval()
        max_len = self._resolve_max_len(max_len)
        device = images.device
        memory = self._encode(images)  # [B, S, embed_dim]

        final_outputs = []
        for b in range(memory.size(0)):
            mem = memory[b:b + 1]
            seq = [self.sos_token_id]
            for _ in range(max_len):
                tgt = torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(0)
                logits = self.decoder(tgt, mem)
                next_token = logits[:, -1, :].argmax(dim=-1).item()
                seq.append(next_token)
                if next_token == self.eos_token_id:
                    break
            final_outputs.append(self._strip_special(seq))

        return self._pad_sequences(final_outputs, device)
class EarlyStopping:
    """Track validation loss and report when training should stop.

    A call counts as an improvement when the loss is lower than the best loss
    seen so far by more than ``min_delta``. Each non-improving call increases
    a counter; once the counter reaches ``patience`` the call returns True.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Minimum decrease of the loss that counts as an improvement.
    """

    def __init__(self, patience=3, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = float('inf')
        self.counter = 0

    def __call__(self, val_loss):
        """Record ``val_loss``; return True to stop training, False to continue."""
        threshold = self.best_loss - self.min_delta
        if val_loss < threshold:
            # New best: remember it and start counting from scratch.
            self.best_loss = val_loss
            self.counter = 0
            return False

        self.counter += 1
        return self.counter >= self.patience