""" Patch-to-Patch Vision Transformer Models ---------------------------------------- This module implements two variants of Vision Transformer (ViT) architectures for dense regression tasks, designed around a patch-to-patch learning paradigm. Both models decompose images into non-overlapping patches, process them through a transformer encoder, and then reconstruct the output image. Components: - Patchify: Splits an input image into flattened non-overlapping patches. - Unpatchify: Reconstructs an image tensor from a sequence of flattened patches. - ViT_Patch2Patch (Version 1): • Patchify + linear projection into embedding space. • Sinusoidal positional encoding. • Transformer encoder with configurable depth/heads. • Linear decoder back to patch space, then Unpatchify. - ViT_Patch2Patch_ver2 (Version 2): • Patch embedding via Conv2d (stride = patch size). • Learned positional embeddings with dropout. • Transformer encoder with configurable depth/heads. • CNN-based decoder with PixelShuffle layers for super-resolution-style upsampling back to the original image resolution. Utilities: - test_model: Simple wrapper to run a forward pass and log output shapes. - main: Runs lightweight tests of both model variants on a dummy input tensor. Usage Example: >>> import torch >>> from vit_patch2patch import ViT_Patch2Patch, ViT_Patch2Patch_ver2 >>> model = ViT_Patch2Patch(img_size=512, patch_size=8, in_ch=3, out_ch=3) >>> dummy = torch.randn(1, 3, 512, 512) >>> out = model(dummy) >>> print(out.shape) # torch.Size([1, 3, 512, 512]) Notes: - Logging is used to track initialization parameters and parameter counts. - Default settings assume square images (H = W = img_size). - PixelShuffle decoder in Version 2 assumes patch_size divisible by upscaling factors. """ import logging import torch from torch import Tensor import torch.nn as nn # import timm ############################################################################################################## class PositionalEncoding(nn.Module): def __init__(self, emb_size: int, max_len: int = 1000): """ Sinusoidal Positional Encoding Module. Args: emb_size (int): The size of the embedding dimension. max_len (int): The maximum length of the sequence. """ super(PositionalEncoding, self).__init__() self.logger = logging.getLogger(self.__class__.__name__) self.logger.info( f"Initializing PositionalEncoding with emb_size={emb_size}, max_len={max_len}" ) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp( torch.arange(0, emb_size, 2).float() * (-torch.log(10000.0) / emb_size) ) pe = torch.zeros(max_len, emb_size) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer("positional_encoding", pe.unsqueeze(0)) def forward(self, x: Tensor) -> Tensor: """ Add positional encoding to the input tensor. Args: x (Tensor): Input tensor of shape (batch_size, seq_len, emb_size). Returns: Tensor: Positional encoded tensor of the same shape as input. """ seq_len = x.size(1) self.logger.debug(f"Adding positional encoding to tensor of shape {x.shape}") # Ensure positional encoding is on the same device as x self.positional_encoding = self.positional_encoding.to(x.device) return x + self.positional_encoding[:, :seq_len, :] ############################################################################################################## class Patchify(nn.Module): def __init__(self, patch_size): super().__init__() self.patch_size = patch_size self.logger = logging.getLogger(self.__class__.__name__) self.logger.info( f"Initializing {self.__class__.__name__} | patch_size = {self.patch_size}" ) def forward(self, x): # (B, C, H, W) B, C, H, W = x.shape p = self.patch_size assert H % p == 0 and W % p == 0 x = x.unfold(2, p, p).unfold(3, p, p) # B, C, H//p, W//p, p, p x = x.permute(0, 2, 3, 1, 4, 5).flatten(1, 3) # B, N, C, p, p return x.reshape(B, -1, C * p * p) # B, N, patch_dim class Unpatchify(nn.Module): def __init__(self, patch_size, out_channels, image_size): super().__init__() self.patch_size = patch_size self.out_channels = out_channels self.image_size = image_size self.logger = logging.getLogger(self.__class__.__name__) self.logger.info( f"Initializing {self.__class__.__name__} | patch_size = {self.patch_size} | out_channels={out_channels} | image_size={image_size}" ) def forward(self, x): # (B, N, patch_dim) B, N, D = x.shape p = self.patch_size H, W = self.image_size C = self.out_channels x = x.reshape(B, H // p, W // p, C, p, p) x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W) return x class ViT_Patch2Patch(nn.Module): def __init__( self, img_size=512, patch_size=8, in_ch=3, out_ch=3, embed_dim=512, depth=6, heads=8, ): super().__init__() self.patch_size = patch_size self.img_size = img_size self.num_patches = (img_size // patch_size) ** 2 self.patch_dim = in_ch * patch_size * patch_size self.output_dim = out_ch * patch_size * patch_size self.logger = logging.getLogger(self.__class__.__name__) self.logger.info( f"Initalizing {self.__class__.__name__} | img_size={img_size} | patch_size={patch_size}" "| in_ch={in_ch} | out_ch={out_ch} | embed_dim={embed_dim} | depth={depth} | heads={heads}" ) # Modules self.patchify = Patchify(patch_size) self.proj = nn.Linear(self.patch_dim, embed_dim) self.pos_encoding = PositionalEncoding( emb_size=embed_dim, max_len=self.num_patches ) encoder_layer = nn.TransformerEncoderLayer( embed_dim, heads, dim_feedforward=embed_dim * 4, batch_first=True ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth) self.decoder = nn.Linear(embed_dim, self.output_dim) self.unpatchify = Unpatchify(patch_size, out_ch, (img_size, img_size)) self._log_parameter_count() def _log_parameter_count(self): """ Logs total and trainable parameters in the model, summarized by top-level modules. """ self.logger.info( f"{self.__class__.__name__} Parameter Summary (Top-Level Modules):" ) self.logger.info("-" * 80) total_params = 0 trainable_params = 0 for name, module in self.named_children(): # Only top-level children mod_total = sum(p.numel() for p in module.parameters()) mod_trainable = sum( p.numel() for p in module.parameters() if p.requires_grad ) total_params += mod_total trainable_params += mod_trainable self.logger.info( f"{name:<25} | Total: {mod_total:<20} | Trainable: {mod_trainable:,}" ) self.logger.info("-" * 80) self.logger.info(f"Total Parameters: {total_params:,}") self.logger.info(f"Trainable Parameters: {trainable_params:,}") def forward(self, x): x = self.patchify(x) # (B, N, patch_dim) x = self.proj(x) # (B, N, embed_dim) x = self.pos_encoding(x) # (B, N, embed_dim) x = self.encoder(x) # (B, N, embed_dim) x = self.decoder(x) # (B, N, patch_output_dim) x = self.unpatchify(x) # (B, out_ch, H, W) return x class ViT_Patch2Patch_ver2(nn.Module): """ SOme changes from above: - learned patch embed using a conv layer with kernelsize=patchsize - learned positional embedinng, no longer using sinusoidal - added some dropout - decoder: - replaced simple linear decoder from embed dim to output dim (pre patchify) - using PixelShuffle super resolution technique o https://docs.pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html """ def __init__( self, img_size=512, patch_size=8, in_ch=3, out_ch=3, embed_dim=512, depth=6, heads=8, dropout=0.0, ): super().__init__() self.img_size = img_size self.patch_size = patch_size self.num_patches = (img_size // patch_size) ** 2 self.embed_dim = embed_dim self.logger = logging.getLogger(self.__class__.__name__) self.logger.info( f"Initialized {self.__class__.__name__} with img_size={img_size}, patch_size={patch_size}" ) # Patch embedding via conv self.patch_embed = nn.Conv2d( in_ch, embed_dim, kernel_size=patch_size, stride=patch_size ) self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim)) self.pos_dropout = nn.Dropout(p=dropout) # Transformer encoder encoder_layer = nn.TransformerEncoderLayer( d_model=embed_dim, nhead=heads, dim_feedforward=embed_dim * 4, dropout=dropout, batch_first=True, ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth) # Decoder: maps transformer output back to full-resolution image self.decoder = nn.Sequential( # IF patchsize = 16 nn.Conv2d( embed_dim, 512, 3, padding=1 ), # [B, 768, 32, 32] -> [B, 512, 32, 32] nn.ReLU(), nn.Conv2d( 512, 1024, 3, padding=1 ), # [B, 512, 32, 32] -> [B, 1024, 32, 32] nn.ReLU(), nn.PixelShuffle(4), # [B, 1024, 32, 32] -> [B, 64, 128, 128] nn.Conv2d(64, 64, 3, padding=1), # [B, 64, 128, 128] -> [B, 64, 128, 128] nn.ReLU(), nn.Conv2d(64, 64, 3, padding=1), # [B, 64, 128, 128] -> [B, 64, 128, 128] nn.ReLU(), nn.PixelShuffle(4), # [B, 64, 128, 128] -> [B, 4, 512, 512] nn.Conv2d(4, out_ch, 3, padding=1), # [B, 4, 512, 512] -> [B, 3, 512, 512] ) self._log_parameter_count() def _log_parameter_count(self): """ Logs total and trainable parameters in the model, summarized by top-level modules. """ self.logger.info( f"{self.__class__.__name__} Parameter Summary (Top-Level Modules):" ) self.logger.info("-" * 80) total_params = 0 trainable_params = 0 for name, module in self.named_children(): # Only top-level children mod_total = sum(p.numel() for p in module.parameters()) mod_trainable = sum( p.numel() for p in module.parameters() if p.requires_grad ) total_params += mod_total trainable_params += mod_trainable self.logger.info( f"{name:<25} | Total: {mod_total:<20} | Trainable: {mod_trainable:,}" ) self.logger.info("-" * 80) self.logger.info(f"Total Parameters: {total_params:,}") self.logger.info(f"Trainable Parameters: {trainable_params:,}") def forward(self, x): B = x.size(0) # Patch embedding x = self.patch_embed(x) # [B, embed_dim, H//p, W//p] H_p, W_p = x.shape[2], x.shape[3] x = x.flatten(2).transpose(1, 2) # [B, N, embed_dim] # Add positional embedding x = x + self.pos_embed[:, : x.size(1), :] x = self.pos_dropout(x) # Transformer x = self.encoder(x) # [B, N, embed_dim] # Reshape back to 2D grid x = x.transpose(1, 2).reshape( B, self.embed_dim, H_p, W_p ) # [B, embed_dim, H//p, W//p] # Decode to full-res output out = self.decoder(x) # [B, out_ch, H, W] return out # ============================================================================================================ # TESTING # ============================================================================================================ def test_model(model, name, input_tensor): try: print(f"Testing {name}...") out = model(input_tensor) print(f"{name} output shape: {out.shape}\n") except Exception as e: print(f"{name} failed with error: {e}\n") def main(): logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") print("Starting model tests...\n") # Set image parameters B, C, H, W = 1, 3, 512, 512 dummy_input = torch.randn(B, C, H, W) # 1. ViT Patch2Patch (version 1) model1 = ViT_Patch2Patch(img_size=512, patch_size=8, in_ch=3, out_ch=3) test_model(model1, "ViT_Patch2Patch (version 1)", dummy_input) print("==" * 50) # 2. ViT Patch2Patch (version 2) model1 = ViT_Patch2Patch_ver2(img_size=512, patch_size=8, in_ch=3, out_ch=3) test_model(model1, "ViT_Patch2Patch (version 2)", dummy_input) print("==" * 50) if __name__ == "__main__": main()