from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.unets.unet_2d_blocks import (
    AutoencoderTinyBlock,
    UNetMidBlock2D,
    get_down_block,
    get_up_block,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config

from .vae import Encoder
from .nn_utils import PositionalEncoding1D, A2DPE


# Parameters should be 5.438.209
# 0 | feature_extractor | Encoder | 1.6 M
# 1 | quant_conv        | Conv2d  | 16.5 K
# 2 | htr               | HTR     | 3.9 M
# Parameters: 5.437.695
class HTR(ModelMixin, ConfigMixin):
    @register_to_config  # store the __init__ arguments in the model config (ConfigMixin)
    def __init__(self,
                 alphabet_size: int = 169,
                 in_channels: int = 3,
                 down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
                 block_out_channels: Tuple[int] = (64,),
                 layers_per_block: int = 1,
                 act_fn: str = "silu",
                 latent_channels: int = 128,
                 d_model: int = 128,
                 norm_num_groups: int = 16,
                 encoder_dropout: float = 0.1,
                 use_tgt_pe: bool = True,
                 use_mem_pe: bool = True,
                 htr_dropout: float = 0.1,
                 num_encoder_layers: int = 2,
                 num_decoder_layers: int = 4,
                 only_head: bool = False,
                 ):
        super().__init__()
        self.only_head = only_head

        if not self.only_head:
            self.feature_extractor = Encoder(
                in_channels=in_channels,
                out_channels=latent_channels,
                down_block_types=down_block_types,
                block_out_channels=block_out_channels,
                layers_per_block=layers_per_block,
                act_fn=act_fn,
                norm_num_groups=norm_num_groups,
                double_z=False,
                dropout=encoder_dropout,
            )

        self.quant_conv = nn.Conv2d(latent_channels, d_model, 1)

        # Letter classification
        self.text_embedding = nn.Embedding(alphabet_size, d_model)
        self.d_model = d_model
        self.mem_pe = A2DPE(d_model=d_model, dropout=htr_dropout) if use_mem_pe else None
        self.tgt_pe = PositionalEncoding1D(d_model=d_model, dropout=htr_dropout) if use_tgt_pe else None

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=1)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_encoder_layers, norm=nn.LayerNorm(d_model), enable_nested_tensor=False)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=1)
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=num_decoder_layers, norm=nn.LayerNorm(d_model))

        self.fc = nn.Linear(d_model, alphabet_size)

    def forward(self, x, tgt_logits, tgt_mask, tgt_key_padding_mask):
        if not self.only_head:
            # Feature extraction
            memory = self.feature_extractor(x)  # [B, 1, 64, 768] -> [B, 128, 8, 96]
        else:
            # If running the latent HTR head only, the input is already a latent map, e.g. [B, 1, 8, 96]
            memory = x
        memory = self.quant_conv(memory)

        # Letter classification
        if self.mem_pe is not None:
            memory = self.mem_pe(memory)
        memory = rearrange(memory, "b c h w -> (h w) b c")
        memory = self.transformer_encoder(memory)

        tgt = self.text_embedding(tgt_logits)  # tgt_logits holds target token indices, shape [B, S]
        if self.tgt_pe is not None:
            tgt = self.tgt_pe(tgt)
        tgt = rearrange(tgt, "b s d -> s b d")
        tgt = self.transformer_decoder(tgt, memory, tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask.float())
        tgt = rearrange(tgt, "s b d -> b s d")
        tgt = self.fc(tgt)
        return tgt

    def reset_last_layer(self, alphabet_size):
        # Replace the classification head, e.g. when switching to a different alphabet
        self.fc = nn.Linear(self.d_model, alphabet_size)
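
# --- Minimal usage sketch (illustrative addition, not part of the original module) ---
# Assumes a 1-channel 64x768 line image, matching the shape comment in forward(), and
# builds the causal target mask with nn.Transformer.generate_square_subsequent_mask;
# the actual training code may construct inputs and masks differently.
# Note: because of the relative imports above, run this as a module (python -m ...),
# not as a standalone script.
if __name__ == "__main__":
    model = HTR(alphabet_size=169, in_channels=1)
    images = torch.randn(2, 1, 64, 768)                              # [B, C, H, W] line images
    tgt_tokens = torch.randint(0, 169, (2, 20))                      # [B, S] target token indices
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(20)    # [S, S] causal mask
    tgt_key_padding_mask = torch.zeros(2, 20, dtype=torch.bool)      # [B, S], True marks padding
    logits = model(images, tgt_tokens, tgt_mask, tgt_key_padding_mask)
    print(logits.shape)  # expected: torch.Size([2, 20, 169])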