Valerii Sielikhov
Revert "Refactor import statements to remove relative imports in htr_convtext.py, modeling_htr.py, mv_block.py, and resnet18.py"
3529c7b | import warnings | |
| from functools import partial | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from timm.layers import LayerScale | |
| from timm.models.vision_transformer import DropPath | |
| from .resnet18 import ResNet18 | |
class RelativePositionBias1D(nn.Module):
    """Learned additive attention bias indexed by signed 1-D relative distance."""

    def __init__(self, num_heads: int, max_rel_positions: int = 1024):
        super().__init__()
        self.num_heads = num_heads
        self.max_rel_positions = max(1, int(max_rel_positions))
        # One learnable per-head row for each distance in [-(M-1), M-1].
        self.bias = nn.Embedding(2 * self.max_rel_positions - 1, num_heads)
        nn.init.zeros_(self.bias.weight)

    def forward(self, N: int) -> torch.Tensor:
        """Return a (1, num_heads, N, N) additive bias for a length-N sequence."""
        dev = self.bias.weight.device
        positions = torch.arange(N, device=dev)
        offset = self.max_rel_positions - 1
        distance = positions.unsqueeze(1) - positions.unsqueeze(0)
        # Clamp out-of-range distances to the table edges, then shift to [0, 2M-2].
        index = torch.clamp(distance, min=-offset, max=offset) + offset
        table = self.bias(index)  # (N, N, num_heads)
        return table.permute(2, 0, 1).unsqueeze(0)
class Attention(nn.Module):
    """Multi-head self-attention with a learned 1-D relative position bias."""

    def __init__(
        self,
        dim,
        num_patches,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        # Size the relative-bias table from the expected sequence length;
        # fall back to 1024 positions when no length is given.
        if num_patches is not None:
            table_span = max(1, int(num_patches))
        else:
            table_span = 1024
        self.rel_pos_bias = RelativePositionBias1D(
            num_heads=num_heads, max_rel_positions=table_span
        )
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Map a (B, N, C) sequence to a (B, N, C) attention output."""
        B, N, C = x.shape
        head_dim = C // self.num_heads
        fused = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        scores = scores + self.rel_pos_bias(N)
        weights = self.attn_drop(scores.softmax(dim=-1))
        context = torch.matmul(weights, v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(context))
class FeedForward(nn.Module):
    """Two-layer position-wise MLP with an activation and output dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.1, activation=nn.SiLU):
        super().__init__()
        self.lin1 = nn.Linear(dim, hidden_dim)
        self.act = activation()
        self.lin2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.act(self.lin1(x))
        return self.dropout(self.lin2(hidden))
class ConvModule(nn.Module):
    """Pointwise-expand -> depthwise conv -> GroupNorm -> pointwise-project block.

    Operates on (B, N, C) sequences, transposing to channel-first for conv1d.
    """

    def __init__(
        self,
        dim,
        kernel_size=3,
        dropout=0.1,
        drop_path=0.0,
        expansion=1.0,
        pre_norm=False,
        activation=nn.SiLU,
    ):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim) if pre_norm else None
        hidden = int(round(dim * expansion))
        self.pw1 = nn.Conv1d(dim, hidden, kernel_size=1, bias=True)
        self.act1 = activation()
        # Depthwise conv: one filter per channel, "same" padding.
        self.dw = nn.Conv1d(
            hidden,
            hidden,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=hidden,
            bias=True,
        )
        self.gn = nn.GroupNorm(1, hidden, eps=1e-5)
        self.act2 = activation()
        self.pw2 = nn.Conv1d(hidden, dim, kernel_size=1, bias=True)
        self.dropout = nn.Dropout(dropout)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        if self.pre_norm is not None:
            x = self.pre_norm(x)
        seq = x.transpose(1, 2)  # (B, C, N)
        seq = self.act1(self.pw1(seq))
        seq = self.act2(self.gn(self.dw(seq)))
        seq = self.pw2(seq)
        out = self.dropout(seq).transpose(1, 2)
        return self.drop_path(out)
class Downsample1D(nn.Module):
    """Strided depthwise + pointwise conv that shortens a (B, N, C) sequence."""

    def __init__(self, dim, kernel_size=3, stride=2, lowpass_init=True):
        super().__init__()
        self.dw = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=dim,
            bias=False,
        )
        self.pw = nn.Conv1d(dim, dim, kernel_size=1, bias=True)
        if lowpass_init:
            # Initialize the depthwise filter as a moving-average (box) filter
            # so the downsample starts as a benign low-pass operation.
            with torch.no_grad():
                self.dw.weight.zero_()
                self.dw.weight[:, 0, :].fill_(1.0 / kernel_size)

    def forward(self, x):
        z = self.pw(self.dw(x.transpose(1, 2)))
        return z.transpose(1, 2)
class Upsample1D(nn.Module):
    """Interpolate a (B, N, C) sequence to a target length, then 1x1-project."""

    def __init__(self, dim, mode: str = "nearest"):
        super().__init__()
        assert mode in ("nearest", "linear"), (
            "Upsample1D mode must be 'nearest' or 'linear'"
        )
        self.mode = mode
        self.proj = nn.Conv1d(dim, dim, kernel_size=1, bias=True)

    def forward(self, x, target_len: int):
        z = x.transpose(1, 2)  # (B, C, N) for interpolate/conv
        if self.mode == "linear":
            z = F.interpolate(z, size=target_len, mode="linear", align_corners=False)
        else:
            z = F.interpolate(z, size=target_len, mode="nearest")
        return self.proj(z).transpose(1, 2)
class ConvTextBlock(nn.Module):
    """Conformer-style block: attention, half-step FFN, conv module, half-step FFN.

    Each sub-module is applied residually with LayerScale and DropPath, followed
    by a LayerNorm (Post-LN ordering). The two FFN branches are scaled by 0.5.
    """

    def __init__(
        self,
        dim,
        num_heads,
        num_patches,
        mlp_ratio=4.0,
        ff_dropout=0.1,
        attn_dropout=0.0,
        conv_dropout=0.0,
        conv_kernel_size=3,
        conv_expansion=1.0,
        norm_layer=nn.LayerNorm,
        drop_path=0.0,
        layerscale_init=1e-5,
    ):
        super().__init__()
        ff_hidden = int(dim * mlp_ratio)
        self.attn = Attention(
            dim,
            num_patches,
            num_heads=num_heads,
            qkv_bias=True,
            attn_drop=attn_dropout,
            proj_drop=ff_dropout,
        )
        self.ffn1 = FeedForward(dim, ff_hidden, dropout=ff_dropout, activation=nn.SiLU)
        self.conv = ConvModule(
            dim,
            kernel_size=conv_kernel_size,
            dropout=conv_dropout,
            drop_path=0.0,
            expansion=conv_expansion,
            pre_norm=False,
            activation=nn.SiLU,
        )
        self.ffn2 = FeedForward(dim, ff_hidden, dropout=ff_dropout, activation=nn.SiLU)

        def make_norm():
            return norm_layer(dim, elementwise_affine=True)

        def make_drop_path():
            return DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.postln_attn = make_norm()
        self.postln_ffn1 = make_norm()
        self.postln_conv = make_norm()
        self.postln_ffn2 = make_norm()
        self.dp_attn = make_drop_path()
        self.dp_ffn1 = make_drop_path()
        self.dp_conv = make_drop_path()
        self.dp_ffn2 = make_drop_path()
        self.ls_attn = LayerScale(dim, init_values=layerscale_init)
        self.ls_ffn1 = LayerScale(dim, init_values=layerscale_init)
        self.ls_conv = LayerScale(dim, init_values=layerscale_init)
        self.ls_ffn2 = LayerScale(dim, init_values=layerscale_init)

    def forward(self, x):
        attn_branch = self.ls_attn(self.dp_attn(self.attn(x)))
        x = self.postln_attn(x + attn_branch)
        ffn1_branch = self.ls_ffn1(0.5 * self.dp_ffn1(self.ffn1(x)))
        x = self.postln_ffn1(x + ffn1_branch)
        conv_branch = self.ls_conv(self.dp_conv(self.conv(x)))
        x = self.postln_conv(x + conv_branch)
        ffn2_branch = self.ls_ffn2(0.5 * self.dp_ffn2(self.ffn2(x)))
        x = self.postln_ffn2(x + ffn2_branch)
        return x
def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """Build a (grid_h * grid_w, embed_dim) 2-D sin-cos positional embedding.

    grid_size is (height, width); no cls-token row is prepended.
    """
    ys = np.arange(grid_size[0], dtype=np.float32)
    xs = np.arange(grid_size[1], dtype=np.float32)
    mesh = np.stack(np.meshgrid(xs, ys), axis=0)  # width varies fastest
    mesh = mesh.reshape([2, 1, grid_size[0], grid_size[1]])
    return get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode each grid axis with half of embed_dim and concatenate the halves."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    axis_embeds = [
        get_1d_sincos_pos_embed_from_grid(half, grid[0]),
        get_1d_sincos_pos_embed_from_grid(half, grid[1]),
    ]
    return np.concatenate(axis_embeds, axis=1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Return an (len(pos), embed_dim) sin/cos embedding of 1-D positions.

    The first half of each row holds sines, the second half cosines, over
    geometrically spaced frequencies (base 10000).
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=np.float64) / half)
    angles = np.outer(pos.reshape(-1), freqs)  # (M, half)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
class HTR_ConvText(nn.Module):
    """ConvText encoder for handwritten text recognition.

    Pipeline: ResNet18 patch embedding -> flatten to a 1-D token sequence ->
    ``depth`` ConvTextBlocks with one Downsample1D after block ``down_after``
    and a skip-connected Upsample1D after block ``up_after`` -> LayerNorm ->
    linear head giving per-position logits over ``nb_cls`` classes.

    Optional masking (random / block / span) replaces tokens with a learned
    mask token for masked pretraining.
    """

    def __init__(
        self,
        nb_cls=80,
        img_size=[512, 64],
        patch_size=[4, 32],
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
        conv_kernel_size: int = 3,
        dropout: float = 0.1,
        drop_path: float = 0.1,
        down_after: int = 2,
        up_after: int = 4,
        ds_kernel: int = 3,
        max_seq_len: int = 1024,
        upsample_mode: str = "nearest",
    ):
        super().__init__()
        self.patch_embed = ResNet18(embed_dim)
        self.embed_dim = embed_dim
        self.max_rel_pos = int(max_seq_len)
        # Learned embedding substituted at masked positions.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Stochastic-depth rate increases linearly with block index.
        dpr = np.linspace(0.0, float(drop_path), depth, dtype=float).tolist()
        self.blocks = nn.ModuleList(
            [
                ConvTextBlock(
                    embed_dim,
                    num_heads,
                    self.max_rel_pos,
                    mlp_ratio=mlp_ratio,
                    ff_dropout=dropout,
                    attn_dropout=dropout,
                    conv_dropout=dropout,
                    conv_kernel_size=conv_kernel_size,
                    conv_expansion=1.0,
                    norm_layer=norm_layer,
                    drop_path=dpr[i],
                    layerscale_init=1e-5,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim, elementwise_affine=True)
        self.head = torch.nn.Linear(embed_dim, nb_cls)
        self.down_after = down_after
        self.up_after = up_after
        self.down1 = Downsample1D(embed_dim, kernel_size=ds_kernel)
        self.up1 = Upsample1D(embed_dim, mode=upsample_mode)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the mask token and all Linear/LayerNorm submodules."""
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Xavier for linear layers; identity affine for LayerNorm.
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def mask_random_1d(self, x, ratio):
        """Return a (B, L) bool keep-mask with round(ratio * L) positions set
        to False per sample, chosen uniformly at random. Ratios outside
        (0, 1] keep everything (original behavior, preserved)."""
        B, L, _ = x.shape
        mask = torch.ones(B, L, dtype=torch.bool, device=x.device)
        if ratio <= 0.0 or ratio > 1.0:
            return mask
        num = int(round(ratio * L))
        if num <= 0:
            return mask
        noise = torch.rand(B, L, device=x.device)
        idx = noise.argsort(dim=1)[:, :num]
        mask.scatter_(1, idx, False)
        return mask

    def mask_block_1d(self, x, ratio: float, max_block_length: int):
        """Return a (B, L, 1) bool keep-mask built from K random blocks of
        random length in [1, max_block_length]. Blocks may overlap, so the
        realized masked fraction only approximates ``ratio``."""
        B, L, _ = x.shape
        device = x.device
        if ratio <= 0.0:
            return torch.ones(B, L, 1, dtype=torch.bool, device=device)
        if ratio >= 1.0:
            return torch.zeros(B, L, 1, dtype=torch.bool, device=device)
        target_mask_tokens = int(round(ratio * L))
        K = max(target_mask_tokens // max_block_length, 1)
        starts = torch.randint(
            0, max(1, L - max_block_length + 1), (B, K), device=device
        )
        lengths = torch.randint(1, max_block_length + 1, (B, K), device=device)
        positions = torch.arange(L, device=device).view(1, 1, L)
        starts_exp = starts.unsqueeze(-1)
        ends_exp = (starts + lengths).unsqueeze(-1).clamp(max=L)
        blocks_mask = (positions >= starts_exp) & (positions < ends_exp)
        keep_mask = ~blocks_mask.any(dim=1)
        return keep_mask.unsqueeze(-1)

    def mask_span_1d(self, x, ratio: float, max_span_length: int):
        """Like mask_block_1d, but every span has fixed length max_span_length."""
        B, L, _ = x.shape
        device = x.device
        if ratio <= 0.0:
            return torch.ones(B, L, 1, dtype=torch.bool, device=device)
        if ratio >= 1.0:
            return torch.zeros(B, L, 1, dtype=torch.bool, device=device)
        target_mask_tokens = int(round(ratio * L))
        K = max(target_mask_tokens // max_span_length, 1)
        starts = torch.randint(
            0, max(1, L - max_span_length + 1), (B, K), device=device
        )
        lengths = torch.full((B, K), max_span_length, device=device)
        positions = torch.arange(L, device=device).view(1, 1, L)
        starts_exp = starts.unsqueeze(-1)
        ends_exp = (starts + lengths).unsqueeze(-1).clamp(max=L)
        spans_mask = (positions >= starts_exp) & (positions < ends_exp)
        keep_mask = ~spans_mask.any(dim=1)
        return keep_mask.unsqueeze(-1)

    def forward_features(
        self,
        x,
        use_masking=False,
        mask_mode="span",
        mask_ratio=0.5,
        block_span=4,
        max_span_length=8,
    ):
        """Encode an image batch into per-position features.

        Returns (features, masked_positions) where masked_positions is a
        float (B, L) tensor with 1.0 at masked positions, or None when
        use_masking is False.
        """
        x = self.patch_embed(x)
        B, C, W, H = x.shape
        assert C == self.embed_dim, f"Expected embed_dim {self.embed_dim}, got {C}"
        # Flatten the 2-D feature map into a token sequence (B, W*H, C).
        x = x.view(B, C, -1).permute(0, 2, 1)
        masked_positions_1d = None
        if use_masking:
            # BUGFIX: the original tested `mask_mode in ("block")`, which is a
            # substring check against the string "block" (so e.g. "lo" matched
            # and "" matched); use string equality instead.
            if mask_mode == "random":
                keep_mask_1d = self.mask_random_1d(x, mask_ratio).float()
                mask = keep_mask_1d.unsqueeze(-1)
            elif mask_mode == "block":
                mask = self.mask_block_1d(x, mask_ratio, block_span).float()
                keep_mask_1d = mask.squeeze(-1)
            elif mask_mode == "span":
                mask = self.mask_span_1d(x, mask_ratio, max_span_length).float()
                keep_mask_1d = mask.squeeze(-1)
            else:
                warnings.warn(f"Unknown mask_mode '{mask_mode}', defaulting to span.")
                mask = self.mask_span_1d(x, mask_ratio, max_span_length).float()
                keep_mask_1d = mask.squeeze(-1)
            masked_positions_1d = (1.0 - keep_mask_1d).clamp(min=0.0, max=1.0)
            # Replace dropped tokens with the learned mask token.
            x = mask * x + (1.0 - mask) * self.mask_token.expand(
                x.size(0), x.size(1), x.size(2)
            )
        skip_hi = None
        for i, blk in enumerate(self.blocks, 1):
            x = blk(x)
            if i == self.down_after:
                skip_hi = x  # keep high-resolution features for the skip add
                if (x.size(1) % 2) == 1:
                    # Pad to even length so the stride-2 downsample divides cleanly.
                    x = torch.cat([x, x[:, -1:, :]], dim=1)
                x = self.down1(x)
            if i == self.up_after:
                assert skip_hi is not None, "Upsample requires a stored skip."
                x = self.up1(x, target_len=skip_hi.size(1))
                x = x + skip_hi
        x = self.norm(x)
        return x, masked_positions_1d

    def forward(
        self,
        x,
        use_masking=False,
        return_features=False,
        return_mask=False,
        mask_mode="span",
        mask_ratio=None,
        block_span=None,
        max_span_length=None,
    ):
        """Run the encoder and head.

        Returns logits, optionally paired with features and/or the
        masked-position tensor depending on return_features / return_mask.
        """
        # BUGFIX: None defaults previously flowed into forward_features and
        # crashed on `ratio <= 0.0` whenever use_masking=True; substitute
        # forward_features' documented defaults instead.
        feats, masked_positions_1d = self.forward_features(
            x,
            use_masking=use_masking,
            mask_mode=mask_mode,
            mask_ratio=0.5 if mask_ratio is None else mask_ratio,
            block_span=4 if block_span is None else block_span,
            max_span_length=8 if max_span_length is None else max_span_length,
        )
        logits = self.head(feats)
        if return_features and return_mask:
            return logits, feats, masked_positions_1d
        if return_features:
            return logits, feats
        if return_mask:
            return logits, masked_positions_1d
        return logits
def create_model(nb_cls, img_size, mlp_ratio=4, **kwargs):
    """Factory for the default HTR_ConvText configuration.

    Extra keyword arguments are forwarded to the HTR_ConvText constructor.
    """
    return HTR_ConvText(
        nb_cls,
        img_size=img_size,
        patch_size=(4, 64),
        embed_dim=512,
        depth=8,
        num_heads=8,
        mlp_ratio=mlp_ratio,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        conv_kernel_size=7,
        down_after=3,
        up_after=7,
        ds_kernel=3,
        max_seq_len=128,
        upsample_mode="nearest",
        **kwargs,
    )