| | import math |
| | import torch |
| | import torch.nn as nn |
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import SequenceClassifierOutput |
| | from .configuration_captcha import CaptchaConfig |
| |
|
| | class PositionalEncoding(nn.Module): |
| | def __init__(self, d_model, max_len=500): |
| | super().__init__() |
| | pe = torch.zeros(max_len, d_model) |
| | position = torch.arange(0, max_len).unsqueeze(1) |
| | div_term = torch.exp( |
| | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) |
| | ) |
| | pe[:, 0::2] = torch.sin(position * div_term) |
| | pe[:, 1::2] = torch.cos(position * div_term) |
| | pe = pe.unsqueeze(0) |
| | self.register_buffer("pe", pe) |
| |
|
| | def forward(self, x): |
| | return x + self.pe[:, : x.size(1)] |
| |
|
class CaptchaConvolutionalTransformer(PreTrainedModel):
    """CNN + Transformer-encoder captcha recognizer.

    A 4-stage convolutional stack downsamples a grayscale image, the width
    axis is treated as the time/sequence axis, and a Transformer encoder with
    sinusoidal positional encodings classifies each timestep into one of
    ``config.num_chars`` character classes.
    """

    config_class = CaptchaConfig

    def __init__(self, config):
        super().__init__(config)

        # Feature extractor: each MaxPool2d(2, 2) halves both H and W; the
        # third stage pools only the height (kernel (2, 1)) so the width —
        # the future sequence length — is preserved from that point on.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.SiLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=(2, 1)),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.SiLU(),
        )

        self.positional_encoding = PositionalEncoding(config.d_model)

        # Pre-norm encoder layers (norm_first=True) with GELU, batch-first
        # tensors (batch, seq, d_model).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_layers,
        )

        # Per-timestep character classifier.
        self.classifier = nn.Linear(config.d_model, config.num_chars)

        # HF hook: weight init + any post-processing defined by the config.
        self.post_init()

    def forward(self, pixel_values, labels=None):
        """Run the recognizer.

        Args:
            pixel_values: float tensor of shape (batch, 1, H, W). The final
                conv output must satisfy channels * height == config.d_model
                for the reshape below — TODO confirm the expected H upstream.
            labels: optional int64 tensor of per-timestep character class
                indices. When given, a mean cross-entropy loss over all
                timesteps is returned.
                NOTE(review): assumes labels align one-to-one with the
                conv-derived sequence length (W // 4) — confirm against the
                data pipeline; if label length differs, CTC loss is needed.

        Returns:
            SequenceClassifierOutput with ``logits`` of shape
            (batch, seq_len, num_chars) and ``loss`` when labels are given.
        """
        x = self.conv(pixel_values)      # (b, 256, h, w)

        # Make width the sequence axis: (b, w, c, h).
        x = x.permute(0, 3, 1, 2)
        b, t, c, h = x.size()
        # Collapse channels and height into the model dimension: c*h == d_model.
        x = x.reshape(b, t, c * h)

        x = self.positional_encoding(x)
        x = self.transformer(x)
        logits = self.classifier(x)      # (b, t, num_chars)

        # BUG FIX: labels were previously accepted but silently ignored, so
        # the standard HF training loop received no loss. Compute it here.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1).long(),
            )

        return SequenceClassifierOutput(loss=loss, logits=logits)