| """PyTorch MLE (Mnaga Line Extraction) model""" |
|
|
from dataclasses import dataclass
|
|
import torch
import torch.nn as nn
|
|
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from transformers.activations import ACT2FN
|
|
from .configuration_mle import MLEConfig
|
|
|
|
@dataclass
class MLEModelOutput(ModelOutput):
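    """
    Output type of [`MLEModel`].

    Args:
        last_hidden_state (`torch.FloatTensor`, *optional*):
            Raw feature map produced by the decoder.
    """
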
    last_hidden_state: torch.FloatTensor | None = None
|
|
|
|
@dataclass
class MLEForAnimeLineExtractionOutput(ModelOutput):
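    """
    Output type of [`MLEForAnimeLineExtraction`].

    Args:
        last_hidden_state (`torch.FloatTensor`, *optional*):
            Raw output of the underlying [`MLEModel`].
        pixel_values (`torch.Tensor`, *optional*):
            Extracted line art, clipped to [0, 255] and cropped back to the input size.
    """
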
    last_hidden_state: torch.FloatTensor | None = None
    pixel_values: torch.Tensor | None = None
|
|
|
|
class MLEBatchNorm(nn.Module):
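    """BatchNorm2d followed by the activation selected by `config.hidden_act`
    (a `leaky_relu` activation additionally uses `config.negative_slope`)."""
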
    def __init__(
        self,
        config: MLEConfig,
        in_features: int,
    ):
        super().__init__()

        self.norm = nn.BatchNorm2d(in_features, eps=config.batch_norm_eps)

        if config.hidden_act == "leaky_relu":
            self.act_fn = nn.LeakyReLU(negative_slope=config.negative_slope)
        else:
            self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.norm(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        return hidden_states
|
|
|
|
class MLEResBlock(nn.Module):
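    """Pre-activation residual block: two BatchNorm -> activation -> conv
    stages, with a 1x1-conv shortcut whenever the channel count or stride
    changes."""
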
    def __init__(
        self,
        config: MLEConfig,
        in_channels: int,
        out_channels: int,
        stride_size: int,
    ):
        super().__init__()

        self.norm1 = MLEBatchNorm(config, in_channels)
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            config.block_kernel_size,
            stride=stride_size,
            padding=config.block_kernel_size // 2,
        )

        self.norm2 = MLEBatchNorm(config, out_channels)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            config.block_kernel_size,
            stride=1,
            padding=config.block_kernel_size // 2,
        )

        if in_channels != out_channels or stride_size != 1:
            self.resize = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride_size,
            )
        else:
            self.resize = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        output = self.norm1(hidden_states)
        output = self.conv1(output)
        output = self.norm2(output)
        output = self.conv2(output)

        if self.resize is not None:
            resized_input = self.resize(hidden_states)
            output += resized_input
        else:
            output += hidden_states

        return output
|
|
|
|
class MLEEncoderLayer(nn.Module):
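    """One encoder stage: `num_layers` residual blocks, where the first block
    maps `in_features` to `out_features` and the rest preserve the width."""
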
    def __init__(
        self,
        config: MLEConfig,
        in_features: int,
        out_features: int,
        num_layers: int,
        stride_sizes: list[int],
    ):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                MLEResBlock(
                    config,
                    in_channels=in_features if i == 0 else out_features,
                    out_channels=out_features,
                    stride_size=stride_sizes[i],
                )
                for i in range(num_layers)
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            hidden_states = block(hidden_states)
        return hidden_states
|
|
|
|
class MLEEncoder(nn.Module):
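    """Encoder: a sequence of stages whose channel width grows by a factor of
    `config.upsample_ratio` per stage, while every stage after the first
    halves the spatial resolution with a final stride-2 block."""
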
    def __init__(
        self,
        config: MLEConfig,
    ):
        super().__init__()

        self.layers = nn.ModuleList(
            [
                MLEEncoderLayer(
                    config,
                    in_features=(
                        config.in_channels
                        if i == 0
                        else config.in_channels
                        * config.block_patch_size
                        * (config.upsample_ratio ** (i - 1))
                    ),
                    out_features=config.in_channels
                    * config.block_patch_size
                    * (config.upsample_ratio**i),
                    num_layers=num_layers,
                    # Every stage after the first downsamples once, via a
                    # stride-2 block placed at the end of the stage.
                    stride_sizes=(
                        [
                            1 if i_layer < num_layers - 1 else 2
                            for i_layer in range(num_layers)
                        ]
                        if i > 0
                        else [1 for _ in range(num_layers)]
                    ),
                )
                for i, num_layers in enumerate(config.num_encoder_layers)
            ]
        )

    def forward(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]:
        all_hidden_states: tuple[torch.Tensor, ...] = ()
        for layer in self.layers:
            hidden_states = layer(hidden_states)
            all_hidden_states += (hidden_states,)
        return hidden_states, all_hidden_states
|
|
|
|
class MLEUpsampleBlock(nn.Module):
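    """BatchNorm -> activation -> conv, followed by nearest-neighbour
    upsampling by `config.upsample_ratio`."""
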
    def __init__(self, config: MLEConfig, in_features: int, out_features: int):
        super().__init__()

        self.norm = MLEBatchNorm(config, in_features=in_features)
        self.conv = nn.Conv2d(
            in_features,
            out_features,
            config.block_kernel_size,
            stride=1,
            padding=config.block_kernel_size // 2,
        )
        self.upsample = nn.Upsample(scale_factor=config.upsample_ratio)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        output = self.norm(hidden_states)
        output = self.conv(output)
        output = self.upsample(output)

        return output
|
|
|
|
class MLEUpsampleResBlock(nn.Module):
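    """Residual upsampling block: an [`MLEUpsampleBlock`] followed by
    BatchNorm -> activation -> conv, with a 1x1-conv + upsample shortcut when
    the channel count changes (no shortcut is added otherwise)."""
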
    def __init__(self, config: MLEConfig, in_features: int, out_features: int):
        super().__init__()

        self.upsample = MLEUpsampleBlock(
            config, in_features=in_features, out_features=out_features
        )

        self.norm = MLEBatchNorm(config, in_features=out_features)
        self.conv = nn.Conv2d(
            out_features,
            out_features,
            config.block_kernel_size,
            stride=1,
            padding=config.block_kernel_size // 2,
        )

        if in_features != out_features:
            self.resize = nn.Sequential(
                nn.Conv2d(
                    in_features,
                    out_features,
                    kernel_size=1,
                    stride=1,
                ),
                nn.Upsample(scale_factor=config.upsample_ratio),
            )
        else:
            self.resize = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        output = self.upsample(hidden_states)
        output = self.norm(output)
        output = self.conv(output)

        if self.resize is not None:
            output += self.resize(hidden_states)

        return output
|
|
|
|
class MLEDecoderLayer(nn.Module):
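    """One decoder stage: an upsampling residual block followed by stride-1
    residual blocks, with the matching encoder feature map added as a skip
    connection at the end."""
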
    def __init__(
        self,
        config: MLEConfig,
        in_features: int,
        out_features: int,
        num_layers: int,
    ):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                (
                    MLEResBlock(
                        config,
                        in_channels=out_features,
                        out_channels=out_features,
                        stride_size=1,
                    )
                    if i > 0
                    else MLEUpsampleResBlock(
                        config,
                        in_features=in_features,
                        out_features=out_features,
                    )
                )
                for i in range(num_layers)
            ]
        )

    def forward(
        self, hidden_states: torch.Tensor, shortcut_states: torch.Tensor
    ) -> torch.Tensor:
        for block in self.blocks:
            hidden_states = block(hidden_states)

        hidden_states += shortcut_states

        return hidden_states
|
|
|
|
class MLEDecoderHead(nn.Module):
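    """Final decoder stage: a stack of stride-1 residual blocks, then
    BatchNorm -> activation and a 1x1 convolution down to a single output
    channel."""
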
    def __init__(self, config: MLEConfig, num_layers: int):
        super().__init__()

        self.layer = MLEEncoderLayer(
            config,
            in_features=config.block_patch_size,
            out_features=config.last_hidden_channels,
            stride_sizes=[1 for _ in range(num_layers)],
            num_layers=num_layers,
        )
        self.norm = MLEBatchNorm(config, in_features=config.last_hidden_channels)
        self.conv = nn.Conv2d(
            config.last_hidden_channels,
            out_channels=1,
            kernel_size=1,
            stride=1,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.layer(hidden_states)
        hidden_states = self.norm(hidden_states)
        pixel_values = self.conv(hidden_states)
        return pixel_values
|
|
|
|
class MLEDecoder(nn.Module):
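    """Decoder: mirrors the encoder by dividing the channel width by
    `config.upsample_ratio` and upsampling spatially at each stage, consuming
    the encoder's intermediate feature maps as skip connections; the last
    stage is an [`MLEDecoderHead`]."""
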
    def __init__(
        self,
        config: MLEConfig,
    ):
        super().__init__()

        encoder_output_channels = (
            config.in_channels
            * config.block_patch_size
            * (config.upsample_ratio ** (len(config.num_encoder_layers) - 1))
        )
        upsample_ratio = config.upsample_ratio
        num_decoder_layers = config.num_decoder_layers

        self.layers = nn.ModuleList(
            [
                (
                    MLEDecoderLayer(
                        config,
                        in_features=encoder_output_channels // (upsample_ratio**i),
                        out_features=encoder_output_channels
                        // (upsample_ratio ** (i + 1)),
                        num_layers=num_layers,
                    )
                    if i < len(num_decoder_layers) - 1
                    else MLEDecoderHead(
                        config,
                        num_layers=num_layers,
                    )
                )
                for i, num_layers in enumerate(num_decoder_layers)
            ]
        )
|
|
    def forward(
        self,
        last_hidden_states: torch.Tensor,
        encoder_hidden_states: tuple[torch.Tensor, ...],
    ) -> torch.Tensor:
        hidden_states = last_hidden_states
        num_encoder_hidden_states = len(encoder_hidden_states)

        for i, layer in enumerate(self.layers):
            if i < len(self.layers) - 1:
                hidden_states = layer(
                    hidden_states,
                    # Skip connection from the matching encoder stage, walking
                    # the encoder outputs in reverse (the last one equals
                    # `last_hidden_states` and is skipped).
                    encoder_hidden_states[num_encoder_hidden_states - 2 - i],
                )
            else:
                # The final layer is the head and takes no skip connection.
                hidden_states = layer(hidden_states)

        return hidden_states
|
|
|
|
class MLEPretrainedModel(PreTrainedModel):
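    """
    An abstract class to handle weights initialization and a simple interface for
    downloading and loading pretrained models.
    """
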
    config_class = MLEConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
|
|
|
|
class MLEModel(MLEPretrainedModel):
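    """The bare MLE encoder-decoder model, returning the decoder's raw
    single-channel feature map."""
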
    def __init__(self, config: MLEConfig):
        super().__init__(config)
        self.config = config

        self.encoder = MLEEncoder(config)
        self.decoder = MLEDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        encoder_output, all_hidden_states = self.encoder(pixel_values)
        decoder_output = self.decoder(encoder_output, all_hidden_states)

        return decoder_output
|
|
|
|
class MLEForAnimeLineExtraction(MLEPretrainedModel):
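    """MLE model for anime/manga line extraction: runs [`MLEModel`] and
    post-processes its output into pixel values."""
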
    def __init__(self, config: MLEConfig):
        super().__init__(config)

        self.model = MLEModel(config)

        # Initialize weights and apply final processing
        self.post_init()

    def postprocess(
        self, output_tensor: torch.Tensor, input_shape: tuple[int, int]
    ) -> torch.Tensor:
        pixel_values = output_tensor[:, 0, :, :]
        pixel_values = torch.clip(pixel_values, 0, 255)

        # Crop back to the caller's original height and width.
        pixel_values = pixel_values[:, 0 : input_shape[0], 0 : input_shape[1]]
        return pixel_values

    def forward(
        self, pixel_values: torch.Tensor, return_dict: bool = True
    ) -> tuple[torch.Tensor, ...] | MLEForAnimeLineExtractionOutput:
        # Remember the input resolution (H, W) for cropping in postprocess.
        input_image_size = (pixel_values.shape[2], pixel_values.shape[3])

        model_output = self.model(pixel_values)

        if not return_dict:
            return (model_output, self.postprocess(model_output, input_image_size))
        else:
            return MLEForAnimeLineExtractionOutput(
                last_hidden_state=model_output,
                pixel_values=self.postprocess(model_output, input_image_size),
            )
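

# Minimal smoke-test sketch, not part of the model definition. It assumes that
# MLEConfig's defaults describe a valid architecture and that a 256x256 input
# divides evenly through the encoder's stride-2 stages; adjust as needed.
if __name__ == "__main__":
    config = MLEConfig()
    model = MLEForAnimeLineExtraction(config)
    model.eval()

    # A random batch of one image in [0, 255], shaped (batch, channels, H, W).
    pixel_values = torch.rand(1, config.in_channels, 256, 256) * 255

    with torch.no_grad():
        output = model(pixel_values)

    # `pixel_values` holds the extracted line art, cropped to the input size.
    print(output.pixel_values.shape)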
|
|