Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| import math | |
class OCRBackbone(nn.Module):
    """ResNet-34 based visual encoder that emits one feature vector per image column.

    The torchvision ResNet-34 stem and its four residual stages are reused;
    the classification head is replaced by a height-collapsing average pool,
    a 1x1 convolution to ``output_dim`` channels and a LayerNorm.

    Args:
        output_dim: Size of each output feature vector.
        pretrained: When true, the ResNet is built with the
            ``"IMAGENET1K_V1"`` weights; otherwise with ``weights=None``.
        in_channels: ``1`` swaps in a single-channel stem convolution; any
            other value keeps the ResNet's own first convolution.
    """

    def __init__(self, output_dim=512, pretrained=True, in_channels=3):
        super().__init__()
        resnet = models.resnet34(weights="IMAGENET1K_V1" if pretrained else None)

        if in_channels == 1:
            # Grayscale input: use a 1-channel stem whose kernel is the
            # channel-wise mean of the ResNet stem kernels.
            stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
            with torch.no_grad():
                stem.weight.data = resnet.conv1.weight.data.mean(dim=1, keepdim=True)
        else:
            stem = resnet.conv1
        self.conv1 = stem

        # Reuse the remaining ResNet stages unchanged.
        for stage_name in ("bn1", "relu", "maxpool",
                           "layer1", "layer2", "layer3", "layer4"):
            setattr(self, stage_name, getattr(resnet, stage_name))

        # Collapse the height axis to 1 while keeping the width axis as-is.
        self.pool = nn.AdaptiveAvgPool2d((1, None))
        self.output_conv = nn.Conv2d(512, output_dim, kernel_size=1)
        self.norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        """Encode a batch of images into a feature sequence.

        Args:
            x: Image batch laid out as ``[B, C, H, W]``.

        Returns:
            Tensor ``[B, W', output_dim]`` where ``W'`` is the feature-map
            width left after the ResNet stages, normalized with LayerNorm.
        """
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.pool,         # -> [B, 512, 1, W']
            self.output_conv,  # -> [B, output_dim, 1, W']
        )
        for stage in pipeline:
            x = stage(x)

        # Drop the singleton height axis and move channels last.
        columns = x.squeeze(2).transpose(1, 2)  # [B, W', output_dim]
        return self.norm(columns)
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to sequence-first inputs.

    Even feature indices hold ``sin(pos * freq)`` and odd indices hold
    ``cos(pos * freq)``. The table is stored as the non-trainable buffer
    ``pe`` with shape ``(max_len, 1, d_model)``.

    Args:
        d_model: Feature dimension of the inputs (even or odd).
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so trim the angle table to match instead of failing on shape.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, d_model)

    def forward(self, x):
        """Add positional encodings to ``x``.

        Args:
            x: Tensor ``(seq_len, batch_size, d_model)`` with
                ``seq_len <= max_len``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return x + self.pe[:x.size(0)]
class TransformerDecoder(nn.Module):
    """Token decoder: embedding + positional encoding + ``nn.TransformerDecoder`` + vocab projection.

    Args:
        vocab_size: Number of tokens; sizes the embedding and output layer.
        embed_dim: Model width (``d_model``).
        num_heads: Attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
        dropout: Dropout rate passed to each decoder layer.
        max_len: Number of positions in the positional-encoding table.
    """

    def __init__(self, vocab_size, embed_dim=512, num_heads=8, num_layers=3,
                 dropout=0.1, max_len=100):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, max_len=max_len)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, vocab_size)
        # Embeddings are multiplied by sqrt(embed_dim) before positions are added.
        self.embed_scale = embed_dim ** 0.5

    def generate_square_subsequent_mask(self, sz, device=None):
        """Build the additive causal mask that hides future positions.

        Args:
            sz: Sequence length ``T``.
            device: Device to allocate the mask on; ``None`` uses the
                default device, as before.

        Returns:
            Float tensor ``(sz, sz)`` with ``-inf`` strictly above the
            diagonal and ``0`` elsewhere.
        """
        return torch.triu(
            torch.full((sz, sz), float('-inf'), device=device), diagonal=1
        )

    def forward(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None):
        """Decode target tokens against encoder memory.

        Args:
            tgt: Token ids ``(B, T)``.
            memory: Encoder features ``(B, S, D)``.
            tgt_mask: Optional ``(T, T)`` attention mask; a causal mask is
                created when omitted.
            tgt_padding_mask: Optional ``(B, T)`` mask forwarded as
                ``tgt_key_padding_mask``.

        Returns:
            Logits ``(B, T, vocab_size)``.
        """
        # Scale embeddings, then switch to the sequence-first layout that the
        # positional encoding and decoder layers operate on.
        tgt_emb = self.embedding(tgt) * self.embed_scale
        tgt_emb = self.pos_encoding(tgt_emb.transpose(0, 1))  # (T, B, D)

        if tgt_mask is None:
            # Allocate directly on the input's device to avoid a CPU tensor
            # plus a host-to-device copy on every call.
            tgt_mask = self.generate_square_subsequent_mask(
                tgt.size(1), device=tgt.device
            )

        out = self.decoder(
            tgt=tgt_emb,
            memory=memory.transpose(0, 1),  # (S, B, D)
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        return self.fc(out.transpose(0, 1))  # (B, T, vocab_size)
class OCRModel(nn.Module):
    """Encoder-decoder OCR model: image encoder -> projection -> Transformer decoder.

    Args:
        vocab_size: Size of the output vocabulary.
        encoder: Optional encoder module returning ``[B, S, encoder_dim]``;
            an ``OCRBackbone(output_dim=encoder_dim)`` is built when omitted.
        encoder_dim: Feature size produced by the encoder.
        embed_dim: Decoder model width. A linear projection is inserted when
            it differs from ``encoder_dim``.
        num_heads: Attention heads per decoder layer.
        num_layers: Number of decoder layers.
        dropout: Dropout rate used in the decoder.
        sos_token_id: Id of the start-of-sequence token.
        eos_token_id: Id of the end-of-sequence token (also used as padding
            in predictions).
        max_len: Maximum decoded length; also sizes the decoder's
            positional-encoding table.
    """

    def __init__(
        self,
        vocab_size,
        encoder=None,
        encoder_dim=512,
        embed_dim=512,
        num_heads=8,
        num_layers=3,
        dropout=0.1,
        sos_token_id=1,
        eos_token_id=2,
        max_len=100
    ):
        super().__init__()
        self.encoder = encoder if encoder is not None else OCRBackbone(output_dim=encoder_dim)
        # Only project when encoder and decoder widths disagree.
        self.encoder_proj = (
            nn.Linear(encoder_dim, embed_dim)
            if encoder_dim != embed_dim
            else nn.Identity()
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            dropout=dropout,
            max_len=max_len
        )
        self.sos_token_id = sos_token_id
        self.eos_token_id = eos_token_id
        self.max_len = max_len

    def generate_square_subsequent_mask(self, sz: int, device=None) -> torch.Tensor:
        """Return the ``(sz, sz)`` additive causal mask (``-inf`` above the diagonal).

        ``device`` optionally selects where the mask is allocated.
        """
        return torch.triu(
            torch.full((sz, sz), float('-inf'), device=device), diagonal=1
        )

    def forward(self,
                images: torch.Tensor,
                tgt_input: torch.Tensor,
                attention_mask: torch.Tensor
                ) -> torch.Tensor:
        """Teacher-forced forward pass.

        Args:
            images: ``[B, C, H, W]`` image batch.
            tgt_input: ``[B, T]`` token ids (including <SOS>).
            attention_mask: ``[B, T]``; truthy for real tokens, falsy for
                PAD. Boolean and 0/1 integer masks are both accepted.

        Returns:
            Logits ``[B, T, vocab_size]`` (the decoder already applies the
            vocabulary projection).
        """
        memory = self._encode(images)  # [B, S, embed_dim]

        # Causal mask so position t cannot attend to positions > t.
        tgt_mask = self.generate_square_subsequent_mask(
            tgt_input.size(1), device=images.device
        )

        # The decoder wants True at PAD positions. Cast first: `~` on an
        # integer mask is a bitwise NOT (1 -> -2, 0 -> -1), not a logical one.
        tgt_key_padding_mask = ~attention_mask.bool()

        return self.decoder(
            tgt=tgt_input,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_padding_mask=tgt_key_padding_mask
        )

    def _encode(self, images):
        """Run the encoder and projection: ``[B, C, H, W] -> [B, S, embed_dim]``."""
        return self.encoder_proj(self.encoder(images))

    def _resolve_max_len(self, max_len):
        """Pick the decoding length, capped at ``self.max_len``.

        The decoder is built with a positional table of ``self.max_len``
        rows, so decoding past that length would fail inside the decoder.
        """
        return min(max_len or self.max_len, self.max_len)

    def _trim_tokens(self, seq):
        """Drop the leading <SOS> and everything from the first <EOS> onward."""
        if self.eos_token_id in seq:
            return seq[1:seq.index(self.eos_token_id)]
        return seq[1:]

    def _pad_sequences(self, sequences, device):
        """Stack variable-length id lists into a long tensor padded with <EOS>."""
        width = max((len(s) for s in sequences), default=0)
        out = torch.full(
            (len(sequences), width), self.eos_token_id,
            dtype=torch.long, device=device
        )
        for row, seq in enumerate(sequences):
            if seq:
                out[row, :len(seq)] = torch.tensor(seq, dtype=torch.long, device=device)
        return out

    @torch.no_grad()
    def predict(self, images, max_len=None, beam_size=5, length_penalty=0.7):
        """Decode each image with beam search.

        Puts the module in eval mode (it is not switched back) and runs
        without gradient tracking.

        Args:
            images: ``[B, C, H, W]`` image batch.
            max_len: Maximum number of decoding steps; defaults to, and is
                capped at, ``self.max_len``.
            beam_size: Number of hypotheses kept per step (>= 1). At most
                ``vocab_size`` expansions are taken per hypothesis.
            length_penalty: When > 0, hypotheses are ranked by
                ``score / len(seq) ** length_penalty``; otherwise by raw score.

        Returns:
            Long tensor ``[B, L_pred]`` of token ids without <SOS>, cut before
            the first <EOS> and right-padded with ``eos_token_id``.

        Raises:
            ValueError: If ``beam_size`` is smaller than 1.
        """
        if beam_size < 1:
            raise ValueError(f"beam_size must be >= 1, got {beam_size}")

        self.eval()
        max_len = self._resolve_max_len(max_len)
        device = images.device
        eos = self.eos_token_id
        memory = self._encode(images)  # [B, S, embed_dim]

        def ranking_key(hypothesis):
            tokens, score = hypothesis
            if length_penalty > 0:
                return score / (len(tokens) ** length_penalty)
            return score

        decoded = []
        for b in range(memory.size(0)):
            mem = memory[b:b + 1]  # [1, S, D]
            beams = [([self.sos_token_id], 0.0)]

            for _ in range(max_len):
                candidates = []
                for seq, score in beams:
                    # Finished hypotheses are carried over unchanged.
                    if seq[-1] == eos:
                        candidates.append((seq, score))
                        continue

                    tgt = torch.tensor(seq, device=device).unsqueeze(0)  # [1, T]
                    # The decoder builds its own causal mask when none is given.
                    logits = self.decoder(tgt, mem)  # [1, T, vocab_size]
                    logp = torch.log_softmax(logits[0, -1, :], dim=-1)  # [vocab_size]

                    # topk cannot return more entries than the vocabulary has.
                    k = min(beam_size, logp.size(-1))
                    top_logp, top_ids = logp.topk(k)
                    # One transfer per tensor instead of one per element.
                    for token, token_logp in zip(top_ids.tolist(), top_logp.tolist()):
                        candidates.append((seq + [token], score + token_logp))

                beams = sorted(candidates, key=ranking_key, reverse=True)[:beam_size]

                # Stop early once every surviving hypothesis has emitted <EOS>.
                if all(seq[-1] == eos for seq, _ in beams):
                    break

            decoded.append(self._trim_tokens(beams[0][0]))

        return self._pad_sequences(decoded, device)

    @torch.no_grad()
    def predict_greedy(self, images, max_len=None):
        """Decode each image greedily (argmax at every step).

        Puts the module in eval mode (it is not switched back) and runs
        without gradient tracking.

        Args:
            images: ``[B, C, H, W]`` image batch.
            max_len: Maximum number of decoding steps; defaults to, and is
                capped at, ``self.max_len``.

        Returns:
            Long tensor ``[B, L_pred]`` of token ids without <SOS>, cut before
            the first <EOS> and right-padded with ``eos_token_id``.
        """
        self.eval()
        max_len = self._resolve_max_len(max_len)
        device = images.device
        memory = self._encode(images)  # [B, S, embed_dim]

        decoded = []
        for b in range(memory.size(0)):
            mem = memory[b:b + 1]
            seq = [self.sos_token_id]
            for _ in range(max_len):
                tgt = torch.tensor(seq, device=device).unsqueeze(0)
                logits = self.decoder(tgt, mem)
                next_token = logits[:, -1, :].argmax(dim=-1).item()
                seq.append(next_token)
                if next_token == self.eos_token_id:
                    break
            decoded.append(self._trim_tokens(seq))

        return self._pad_sequences(decoded, device)
class EarlyStopping:
    """Signal when validation loss has stopped improving.

    A call counts as an improvement when the loss drops below the best loss
    seen so far by more than ``min_delta``. Each non-improving call
    increments ``counter``; once it reaches ``patience`` the call returns
    ``True``.
    """

    def __init__(self, patience=3, min_delta=0):
        self.best_loss = float('inf')
        self.counter = 0
        self.min_delta = min_delta
        self.patience = patience

    def __call__(self, val_loss):
        """Record ``val_loss``; return ``True`` to stop training, ``False`` to continue."""
        threshold = self.best_loss - self.min_delta
        if val_loss < threshold:
            # Improvement: remember it and restart the patience window.
            self.best_loss, self.counter = val_loss, 0
            return False

        self.counter += 1
        return self.counter >= self.patience