import warnings
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath, LayerScale

from model import resnet18

class RelativePositionBias1D(nn.Module):
    def __init__(self, num_heads: int, max_rel_positions: int = 1024):
        super().__init__()
        self.num_heads = num_heads
        self.max_rel_positions = max(1, int(max_rel_positions))
        self.bias = nn.Embedding(2 * self.max_rel_positions - 1, num_heads)
        nn.init.zeros_(self.bias.weight)

    def forward(self, N: int) -> torch.Tensor:
        device = self.bias.weight.device
        coords = torch.arange(N, device=device)
        rel = coords[:, None] - coords[None, :]
        rel = rel.clamp(-self.max_rel_positions + 1,
                        self.max_rel_positions - 1)
        rel = rel + (self.max_rel_positions - 1)
        bias = self.bias(rel)
        return bias.permute(2, 0, 1).unsqueeze(0)

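# RelativePositionBias1D.forward(N) returns a (1, num_heads, N, N) tensor that
# Attention below adds to its pre-softmax logits, so positional information is
# relative to token distance rather than absolute index, e.g.:
#   bias = RelativePositionBias1D(num_heads=8, max_rel_positions=128)(16)
#   bias.shape  # torch.Size([1, 8, 16, 16])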
class Attention(nn.Module):
    def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        max_rel_positions = max(
            1, int(num_patches)) if num_patches is not None else 1024
        self.rel_pos_bias = RelativePositionBias1D(
            num_heads=num_heads, max_rel_positions=max_rel_positions)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn + self.rel_pos_bias(N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.1, activation=nn.SiLU):
        super().__init__()
        self.lin1 = nn.Linear(dim, hidden_dim)
        self.act = activation()
        self.lin2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.lin2(self.act(self.lin1(x))))

class ConvModule(nn.Module):
    def __init__(self, dim, kernel_size=3, dropout=0.1, drop_path=0.0,
                 expansion=1.0, pre_norm=False, activation=nn.SiLU):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim) if pre_norm else None
        hidden = int(round(dim * expansion))
        self.pw1 = nn.Conv1d(dim, hidden, kernel_size=1, bias=True)
        self.act1 = activation()
        self.dw = nn.Conv1d(hidden, hidden, kernel_size=kernel_size,
                            padding=kernel_size // 2, groups=hidden, bias=True)
        self.gn = nn.GroupNorm(1, hidden, eps=1e-5)
        self.act2 = activation()
        self.pw2 = nn.Conv1d(hidden, dim, kernel_size=1, bias=True)
        self.dropout = nn.Dropout(dropout)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        if self.pre_norm is not None:
            x = self.pre_norm(x)
        z = x.transpose(1, 2)
        z = self.pw1(z)
        z = self.act1(z)
        z = self.dw(z)
        z = self.gn(z)
        z = self.act2(z)
        z = self.pw2(z)
        z = self.dropout(z).transpose(1, 2)
        return self.drop_path(z)

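# Downsample1D halves the token sequence (depth-wise conv with stride 2 plus a
# point-wise projection); Upsample1D interpolates back to an arbitrary target
# length. HTR_ConvText pairs them around a stored skip connection to form a
# small U-shaped bottleneck over the sequence dimension.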
class Downsample1D(nn.Module):
    def __init__(self, dim, kernel_size=3, stride=2, lowpass_init=True):
        super().__init__()
        self.dw = nn.Conv1d(dim, dim, kernel_size=kernel_size,
                            stride=stride, padding=kernel_size // 2,
                            groups=dim, bias=False)
        self.pw = nn.Conv1d(dim, dim, kernel_size=1, bias=True)
        if lowpass_init:
            # Initialise the depth-wise filter as a moving average (low-pass),
            # so downsampling starts out as simple anti-aliased pooling.
            with torch.no_grad():
                w = torch.zeros_like(self.dw.weight)
                w[:, 0, :] = 1.0 / kernel_size
                self.dw.weight.copy_(w)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.pw(self.dw(x))
        return x.transpose(1, 2)

class Upsample1D(nn.Module):
    def __init__(self, dim, mode: str = 'nearest'):
        super().__init__()
        assert mode in (
            'nearest', 'linear'), "Upsample1D mode must be 'nearest' or 'linear'"
        self.mode = mode
        self.proj = nn.Conv1d(dim, dim, kernel_size=1, bias=True)

    def forward(self, x, target_len: int):
        x = x.transpose(1, 2)
        if self.mode == 'nearest':
            x = F.interpolate(x, size=target_len, mode='nearest')
        else:
            x = F.interpolate(x, size=target_len,
                              mode='linear', align_corners=False)
        x = self.proj(x)
        return x.transpose(1, 2)

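# ConvTextBlock applies self-attention, a half-scaled feed-forward, a depth-wise
# convolution module, and a second half-scaled feed-forward (a Conformer-like
# macaron arrangement), each wrapped with DropPath, LayerScale, a residual
# connection and a post-LayerNorm.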
class ConvTextBlock(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 num_patches,
                 mlp_ratio=4.0,
                 ff_dropout=0.1,
                 attn_dropout=0.0,
                 conv_dropout=0.0,
                 conv_kernel_size=3,
                 conv_expansion=1.0,
                 norm_layer=nn.LayerNorm,
                 drop_path=0.0,
                 layerscale_init=1e-5):
        super().__init__()
        ff_hidden = int(dim * mlp_ratio)
        self.attn = Attention(dim, num_patches, num_heads=num_heads,
                              qkv_bias=True, attn_drop=attn_dropout, proj_drop=ff_dropout)
        self.ffn1 = FeedForward(
            dim, ff_hidden, dropout=ff_dropout, activation=nn.SiLU)
        self.conv = ConvModule(dim, kernel_size=conv_kernel_size,
                               dropout=conv_dropout, drop_path=0.0,
                               expansion=conv_expansion, pre_norm=False, activation=nn.SiLU)
        self.ffn2 = FeedForward(
            dim, ff_hidden, dropout=ff_dropout, activation=nn.SiLU)
        self.postln_attn = norm_layer(dim, elementwise_affine=True)
        self.postln_ffn1 = norm_layer(dim, elementwise_affine=True)
        self.postln_conv = norm_layer(dim, elementwise_affine=True)
        self.postln_ffn2 = norm_layer(dim, elementwise_affine=True)
        self.dp_attn = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()
        self.dp_ffn1 = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()
        self.dp_conv = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()
        self.dp_ffn2 = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()
        self.ls_attn = LayerScale(dim, init_values=layerscale_init)
        self.ls_ffn1 = LayerScale(dim, init_values=layerscale_init)
        self.ls_conv = LayerScale(dim, init_values=layerscale_init)
        self.ls_ffn2 = LayerScale(dim, init_values=layerscale_init)

    def forward(self, x):
        x = self.postln_attn(x + self.ls_attn(self.dp_attn(self.attn(x))))
        x = self.postln_ffn1(
            x + self.ls_ffn1(0.5 * self.dp_ffn1(self.ffn1(x))))
        x = self.postln_conv(x + self.ls_conv(self.dp_conv(self.conv(x))))
        x = self.postln_ffn2(
            x + self.ls_ffn2(0.5 * self.dp_ffn2(self.ffn2(x))))
        return x

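# 2-D sin-cos positional-embedding helpers (NumPy). They are not referenced by
# HTR_ConvText below, which relies on the relative position bias and the
# convolution modules for positional information.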
def get_2d_sincos_pos_embed(embed_dim, grid_size):
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
    emb = np.concatenate([emb_h, emb_w], axis=1)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega
    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)
    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb

class HTR_ConvText(nn.Module):
    def __init__(
            self,
            nb_cls=80,
            img_size=[512, 64],
            patch_size=[4, 32],
            embed_dim=1024,
            depth=24,
            num_heads=16,
            mlp_ratio=4.0,
            norm_layer=nn.LayerNorm,
            conv_kernel_size: int = 3,
            dropout: float = 0.1,
            drop_path: float = 0.1,
            down_after: int = 2,
            up_after: int = 4,
            ds_kernel: int = 3,
            max_seq_len: int = 1024,
            upsample_mode: str = 'nearest',
    ):
        super().__init__()
        self.patch_embed = resnet18.ResNet18(embed_dim)
        self.embed_dim = embed_dim
        self.max_rel_pos = int(max_seq_len)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        self.blocks = nn.ModuleList([
            ConvTextBlock(embed_dim, num_heads, self.max_rel_pos,
                          mlp_ratio=mlp_ratio,
                          ff_dropout=dropout, attn_dropout=dropout,
                          conv_dropout=dropout, conv_kernel_size=conv_kernel_size,
                          conv_expansion=1.0,
                          norm_layer=norm_layer, drop_path=dpr[i],
                          layerscale_init=1e-5)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim, elementwise_affine=True)
        self.head = torch.nn.Linear(embed_dim, nb_cls)
        self.down_after = down_after
        self.up_after = up_after
        self.down1 = Downsample1D(embed_dim, kernel_size=ds_kernel)
        self.up1 = Upsample1D(embed_dim, mode=upsample_mode)
        self.initialize_weights()

    def initialize_weights(self):
        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

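    # The three masking helpers below return boolean "keep" masks: True marks a
    # token that keeps its value, False marks a token whose embedding is replaced
    # by mask_token in forward_features. mask_random_1d returns (B, L);
    # mask_block_1d and mask_span_1d return (B, L, 1). "block" draws random
    # lengths up to max_block_length, "span" uses fixed-length spans.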
    def mask_random_1d(self, x, ratio):
        B, L, _ = x.shape
        mask = torch.ones(B, L, dtype=torch.bool, device=x.device)
        if ratio <= 0.0 or ratio > 1.0:
            return mask
        num = int(round(ratio * L))
        if num <= 0:
            return mask
        noise = torch.rand(B, L, device=x.device)
        idx = noise.argsort(dim=1)[:, :num]
        mask.scatter_(1, idx, False)
        return mask

    def mask_block_1d(self, x, ratio: float, max_block_length: int):
        B, L, _ = x.shape
        device = x.device
        if ratio <= 0.0:
            return torch.ones(B, L, 1, dtype=torch.bool, device=device)
        if ratio >= 1.0:
            return torch.zeros(B, L, 1, dtype=torch.bool, device=device)
        target_mask_tokens = int(round(ratio * L))
        K = max(target_mask_tokens // max_block_length, 1)
        starts = torch.randint(0, max(1, L - max_block_length + 1),
                               (B, K), device=device)
        lengths = torch.randint(1, max_block_length + 1, (B, K), device=device)
        positions = torch.arange(L, device=device).view(1, 1, L)
        starts_exp = starts.unsqueeze(-1)
        ends_exp = (starts + lengths).unsqueeze(-1).clamp(max=L)
        blocks_mask = (positions >= starts_exp) & (positions < ends_exp)
        masked_any = blocks_mask.any(dim=1)
        keep_mask = ~masked_any
        return keep_mask.unsqueeze(-1)

    def mask_span_1d(self, x, ratio: float, max_span_length: int):
        B, L, _ = x.shape
        device = x.device
        if ratio <= 0.0:
            return torch.ones(B, L, 1, dtype=torch.bool, device=device)
        if ratio >= 1.0:
            return torch.zeros(B, L, 1, dtype=torch.bool, device=device)
        target_mask_tokens = int(round(ratio * L))
        K = max(target_mask_tokens // max_span_length, 1)
        starts = torch.randint(0, max(1, L - max_span_length + 1),
                               (B, K), device=device)
        lengths = torch.full((B, K), max_span_length, device=device)
        positions = torch.arange(L, device=device).view(1, 1, L)
        starts_exp = starts.unsqueeze(-1)
        ends_exp = (starts + lengths).unsqueeze(-1).clamp(max=L)
        spans_mask = (positions >= starts_exp) & (positions < ends_exp)
        masked_any = spans_mask.any(dim=1)
        keep_mask = ~masked_any
        return keep_mask.unsqueeze(-1)

    def forward_features(self, x, use_masking=False,
                         mask_mode="span",
                         mask_ratio=0.5, block_span=4, max_span_length=8):
        x = self.patch_embed(x)
        B, C, W, H = x.shape
        assert C == self.embed_dim, f"Expected embed_dim {self.embed_dim}, got {C}"
        # Flatten the 2-D feature map into a 1-D token sequence of length W * H.
        x = x.view(B, C, -1).permute(0, 2, 1)
        masked_positions_1d = None
        if use_masking:
            if mask_mode == "random":
                keep_mask_1d = self.mask_random_1d(x, mask_ratio).float()
                mask = keep_mask_1d.unsqueeze(-1)
            elif mask_mode == "block":
                keep_mask = self.mask_block_1d(x, mask_ratio, block_span).float()
                keep_mask_1d = keep_mask.squeeze(-1)
                mask = keep_mask
            elif mask_mode == "span":
                keep_mask = self.mask_span_1d(
                    x, mask_ratio, max_span_length).float()
                keep_mask_1d = keep_mask.squeeze(-1)
                mask = keep_mask
            else:
                warnings.warn(
                    f"Unknown mask_mode '{mask_mode}', defaulting to span.")
                keep_mask = self.mask_span_1d(
                    x, mask_ratio, max_span_length).float()
                keep_mask_1d = keep_mask.squeeze(-1)
                mask = keep_mask
            masked_positions_1d = (1.0 - keep_mask_1d).clamp(min=0.0, max=1.0)
            # Replace masked tokens with the learnable mask token.
            x = mask * x + (1.0 - mask) * \
                self.mask_token.expand(x.size(0), x.size(1), x.size(2))
        skip_hi = None
        for i, blk in enumerate(self.blocks, 1):
            x = blk(x)
            if i == self.down_after:
                # Store the high-resolution sequence as a skip, pad to an even
                # length, then halve the sequence length.
                skip_hi = x
                if (x.size(1) % 2) == 1:
                    x = torch.cat([x, x[:, -1:, :]], dim=1)
                x = self.down1(x)
            if i == self.up_after:
                assert skip_hi is not None, "Upsample requires a stored skip."
                x = self.up1(x, target_len=skip_hi.size(1))
                x = x + skip_hi
        x = self.norm(x)
        return x, masked_positions_1d

    def forward(self, x, use_masking=False, return_features=False, return_mask=False,
                mask_mode="span", mask_ratio=0.5, block_span=4, max_span_length=8):
        feats, masked_positions_1d = self.forward_features(
            x, use_masking=use_masking, mask_mode=mask_mode, mask_ratio=mask_ratio,
            block_span=block_span, max_span_length=max_span_length)
        logits = self.head(feats)
        if return_features and return_mask:
            return logits, feats, masked_positions_1d
        if return_features:
            return logits, feats
        if return_mask:
            return logits, masked_positions_1d
        return logits

def create_model(nb_cls, img_size, mlp_ratio=4, **kwargs):
    model = HTR_ConvText(
        nb_cls,
        img_size=img_size,
        patch_size=(4, 64),
        embed_dim=512,
        depth=8,
        num_heads=8,
        mlp_ratio=mlp_ratio,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        conv_kernel_size=7,
        down_after=3,
        up_after=7,
        ds_kernel=3,
        max_seq_len=128,
        upsample_mode='nearest',
        **kwargs,
    )
    return model
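

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline). The
    # input shape below is an assumption: it presumes model.resnet18.ResNet18
    # accepts a single-channel 64 x 512 line image and returns a feature map
    # with `embed_dim` channels; adjust channels/orientation to your backbone.
    model = create_model(nb_cls=80, img_size=[512, 64])
    dummy = torch.randn(2, 1, 64, 512)
    logits = model(dummy)  # (B, L, nb_cls)
    logits_m, masked = model(dummy, use_masking=True, return_mask=True,
                             mask_mode="span", mask_ratio=0.5)
    print(logits.shape, logits_m.shape,
          None if masked is None else masked.shape)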