# NOTE(review): removed a non-Python artifact header ("maxhuber's picture /
# Upload 14 files / 3336231 verified") left over from a file-hosting page
# scrape; it was not valid Python and caused a SyntaxError at import time.
"""
Patch-to-Patch Vision Transformer Models
----------------------------------------
This module implements two variants of Vision Transformer (ViT) architectures
for dense regression tasks, designed around a patch-to-patch learning paradigm.
Both models decompose images into non-overlapping patches, process them through
a transformer encoder, and then reconstruct the output image.
Components:
- Patchify:
Splits an input image into flattened non-overlapping patches.
- Unpatchify:
Reconstructs an image tensor from a sequence of flattened patches.
- ViT_Patch2Patch (Version 1):
• Patchify + linear projection into embedding space.
• Sinusoidal positional encoding.
• Transformer encoder with configurable depth/heads.
• Linear decoder back to patch space, then Unpatchify.
- ViT_Patch2Patch_ver2 (Version 2):
• Patch embedding via Conv2d (stride = patch size).
• Learned positional embeddings with dropout.
• Transformer encoder with configurable depth/heads.
• CNN-based decoder with PixelShuffle layers for super-resolution-style
upsampling back to the original image resolution.
Utilities:
- test_model:
Simple wrapper to run a forward pass and log output shapes.
- main:
Runs lightweight tests of both model variants on a dummy input tensor.
Usage Example:
>>> import torch
>>> from vit_patch2patch import ViT_Patch2Patch, ViT_Patch2Patch_ver2
>>> model = ViT_Patch2Patch(img_size=512, patch_size=8, in_ch=3, out_ch=3)
>>> dummy = torch.randn(1, 3, 512, 512)
>>> out = model(dummy)
>>> print(out.shape) # torch.Size([1, 3, 512, 512])
Notes:
- Logging is used to track initialization parameters and parameter counts.
- Default settings assume square images (H = W = img_size).
- PixelShuffle decoder in Version 2 assumes patch_size divisible by upscaling factors.
"""
import logging
import math

import torch
import torch.nn as nn
from torch import Tensor

# import timm
##############################################################################################################
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017) added to token embeddings."""

    def __init__(self, emb_size: int, max_len: int = 1000):
        """
        Sinusoidal Positional Encoding Module.
        Args:
            emb_size (int): The size of the embedding dimension (assumed even,
                so sin/cos halves have equal width).
            max_len (int): The maximum length of the sequence.
        """
        super(PositionalEncoding, self).__init__()
        self.logger = logging.getLogger(self.__class__.__name__)
        # BUGFIX: the message hard-coded "max_len=22,472"; log the actual value.
        self.logger.info(
            f"Initializing PositionalEncoding with emb_size={emb_size}, max_len={max_len}"
        )
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # BUGFIX: torch.log() requires a Tensor argument and raises TypeError on a
        # Python float; use math.log for the scalar constant.
        div_term = torch.exp(
            torch.arange(0, emb_size, 2).float() * (-math.log(10000.0) / emb_size)
        )
        pe = torch.zeros(max_len, emb_size)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Buffer (not Parameter): moves with module.to(...)/cuda() and is saved in
        # state_dict, but is never trained.
        self.register_buffer("positional_encoding", pe.unsqueeze(0))

    def forward(self, x: Tensor) -> Tensor:
        """
        Add positional encoding to the input tensor.
        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, emb_size).
                seq_len must be <= max_len (the slice below silently truncates
                the table, it does not extend it).
        Returns:
            Tensor: Positional encoded tensor of the same shape as input.
        """
        seq_len = x.size(1)
        self.logger.debug(f"Adding positional encoding to tensor of shape {x.shape}")
        # BUGFIX: do not reassign the registered buffer on every forward pass;
        # buffers already follow module.to(device), so a local .to() on the slice
        # is enough and leaves the module state untouched.
        return x + self.positional_encoding[:, :seq_len, :].to(x.device)
##############################################################################################################
class Patchify(nn.Module):
    """Split an image batch into flattened non-overlapping square patches.

    Input:  (B, C, H, W)
    Output: (B, N, C * p * p) with N = (H // p) * (W // p), patches ordered
    row-major over the grid; each row is a (C, p, p) patch flattened.
    """

    def __init__(self, patch_size: int):
        """
        Args:
            patch_size (int): Side length p of the square patches.
        """
        super().__init__()
        self.patch_size = patch_size
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.info(
            f"Initializing {self.__class__.__name__} | patch_size = {self.patch_size}"
        )

    def forward(self, x):  # (B, C, H, W)
        """Patchify a batch of images.

        Raises:
            ValueError: if H or W is not divisible by patch_size.
        """
        B, C, H, W = x.shape
        p = self.patch_size
        # BUGFIX: a bare assert is stripped under `python -O` and carries no
        # message; raise an explicit, descriptive error instead.
        if H % p != 0 or W % p != 0:
            raise ValueError(
                f"Image size ({H}x{W}) must be divisible by patch_size={p}"
            )
        x = x.unfold(2, p, p).unfold(3, p, p)  # (B, C, H//p, W//p, p, p)
        x = x.permute(0, 2, 3, 1, 4, 5).flatten(1, 3)  # (B, N, C, p, p)
        return x.reshape(B, -1, C * p * p)  # (B, N, patch_dim)
class Unpatchify(nn.Module):
    """Inverse of Patchify: rebuild an image tensor from flattened patch rows.

    Input:  (B, N, C * p * p), patches in row-major grid order.
    Output: (B, C, H, W) with (H, W) taken from `image_size`.
    """

    def __init__(self, patch_size, out_channels, image_size):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.image_size = image_size
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.info(
            f"Initializing {self.__class__.__name__} | patch_size = {self.patch_size} | out_channels={out_channels} | image_size={image_size}"
        )

    def forward(self, x):  # (B, N, patch_dim)
        batch = x.shape[0]
        p, ch = self.patch_size, self.out_channels
        height, width = self.image_size
        grid_h, grid_w = height // p, width // p
        # One patch per grid cell, channels-first inside the cell:
        # (B, grid_h, grid_w, C, p, p)
        patches = x.reshape(batch, grid_h, grid_w, ch, p, p)
        # Interleave grid axes with intra-patch axes so rows/cols line up,
        # then collapse into the full image: (B, C, grid_h, p, grid_w, p) -> (B, C, H, W)
        image = patches.permute(0, 3, 1, 4, 2, 5).reshape(batch, ch, height, width)
        return image
class ViT_Patch2Patch(nn.Module):
    """Patch-to-patch ViT, version 1.

    Pipeline: Patchify -> linear projection -> sinusoidal positional encoding
    -> TransformerEncoder -> linear decode to patch space -> Unpatchify.
    Maps (B, in_ch, img_size, img_size) to (B, out_ch, img_size, img_size).
    """

    def __init__(
        self,
        img_size=512,
        patch_size=8,
        in_ch=3,
        out_ch=3,
        embed_dim=512,
        depth=6,
        heads=8,
    ):
        """
        Args:
            img_size (int): Input height/width (square images assumed).
            patch_size (int): Patch side length; must divide img_size.
            in_ch (int): Input channels.
            out_ch (int): Output channels.
            embed_dim (int): Transformer embedding dim (must be divisible by heads).
            depth (int): Number of encoder layers.
            heads (int): Number of attention heads.
        """
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_dim = in_ch * patch_size * patch_size
        self.output_dim = out_ch * patch_size * patch_size
        self.logger = logging.getLogger(self.__class__.__name__)
        # BUGFIX: the second half of this message was a plain (non-f) string, so
        # "{in_ch}" etc. were logged literally; also fixed the missing separator
        # space and the "Initalizing" typo.
        self.logger.info(
            f"Initializing {self.__class__.__name__} | img_size={img_size} | patch_size={patch_size}"
            f" | in_ch={in_ch} | out_ch={out_ch} | embed_dim={embed_dim} | depth={depth} | heads={heads}"
        )
        # Modules
        self.patchify = Patchify(patch_size)
        self.proj = nn.Linear(self.patch_dim, embed_dim)
        self.pos_encoding = PositionalEncoding(
            emb_size=embed_dim, max_len=self.num_patches
        )
        encoder_layer = nn.TransformerEncoderLayer(
            embed_dim, heads, dim_feedforward=embed_dim * 4, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.decoder = nn.Linear(embed_dim, self.output_dim)
        self.unpatchify = Unpatchify(patch_size, out_ch, (img_size, img_size))
        self._log_parameter_count()

    def _log_parameter_count(self):
        """
        Logs total and trainable parameters in the model, summarized by top-level modules.
        """
        self.logger.info(
            f"{self.__class__.__name__} Parameter Summary (Top-Level Modules):"
        )
        self.logger.info("-" * 80)
        total_params = 0
        trainable_params = 0
        for name, module in self.named_children():  # Only top-level children
            mod_total = sum(p.numel() for p in module.parameters())
            mod_trainable = sum(
                p.numel() for p in module.parameters() if p.requires_grad
            )
            total_params += mod_total
            trainable_params += mod_trainable
            # CONSISTENCY: thousands separators on both counts (Total previously
            # had width-only formatting while Trainable used grouping).
            self.logger.info(
                f"{name:<25} | Total: {mod_total:<20,} | Trainable: {mod_trainable:,}"
            )
        self.logger.info("-" * 80)
        self.logger.info(f"Total Parameters: {total_params:,}")
        self.logger.info(f"Trainable Parameters: {trainable_params:,}")

    def forward(self, x):
        """Map (B, in_ch, H, W) -> (B, out_ch, H, W)."""
        x = self.patchify(x)      # (B, N, patch_dim)
        x = self.proj(x)          # (B, N, embed_dim)
        x = self.pos_encoding(x)  # (B, N, embed_dim)
        x = self.encoder(x)       # (B, N, embed_dim)
        x = self.decoder(x)       # (B, N, output_dim)
        x = self.unpatchify(x)    # (B, out_ch, H, W)
        return x
class ViT_Patch2Patch_ver2(nn.Module):
    """Patch-to-patch ViT, version 2.

    Some changes from version 1:
    - learned patch embedding via a Conv2d with kernel_size = stride = patch_size
    - learned positional embedding (no longer sinusoidal), with dropout
    - decoder: the simple linear projection back to patch space is replaced by a
      CNN using the PixelShuffle super-resolution technique
      https://docs.pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
    """

    def __init__(
        self,
        img_size=512,
        patch_size=8,
        in_ch=3,
        out_ch=3,
        embed_dim=512,
        depth=6,
        heads=8,
        dropout=0.0,
    ):
        """
        Args:
            img_size (int): Input height/width (square images assumed).
            patch_size (int): Patch side length; must divide img_size.
            in_ch (int): Input channels.
            out_ch (int): Output channels.
            embed_dim (int): Transformer embedding dim (must be divisible by heads).
            depth (int): Number of encoder layers.
            heads (int): Number of attention heads.
            dropout (float): Dropout used for positional embedding and encoder.
        """
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.embed_dim = embed_dim
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.info(
            f"Initialized {self.__class__.__name__} with img_size={img_size}, patch_size={patch_size}"
        )
        # Patch embedding via conv
        self.patch_embed = nn.Conv2d(
            in_ch, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        self.pos_dropout = nn.Dropout(p=dropout)
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        # BUGFIX: the original decoder always applied two PixelShuffle(4) stages,
        # i.e. a fixed 16x upsampling, so the output only matched the input
        # resolution when patch_size == 16 — yet the default (and main()'s call)
        # uses patch_size=8, silently producing a 2x-too-large image. The two
        # shuffle factors are now derived from patch_size so r1 * r2 == patch_size;
        # for patch_size == 16 this reproduces the original architecture exactly.
        if patch_size % 4 == 0:
            r1, r2 = 4, patch_size // 4
        elif patch_size % 2 == 0:
            r1, r2 = 2, patch_size // 2
        else:
            r1, r2 = patch_size, 1
        # Decoder: maps the (H//p, W//p) token grid back to full resolution.
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 64 * r1 * r1, 3, padding=1),
            nn.ReLU(),
            nn.PixelShuffle(r1),  # channels 64*r1^2 -> 64, spatial x r1
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 4 * r2 * r2, 3, padding=1),
            nn.ReLU(),
            nn.PixelShuffle(r2),  # channels 4*r2^2 -> 4, spatial x r2
            nn.Conv2d(4, out_ch, 3, padding=1),
        )
        self._log_parameter_count()

    def _log_parameter_count(self):
        """
        Logs total and trainable parameters in the model, summarized by top-level modules.
        """
        self.logger.info(
            f"{self.__class__.__name__} Parameter Summary (Top-Level Modules):"
        )
        self.logger.info("-" * 80)
        total_params = 0
        trainable_params = 0
        for name, module in self.named_children():  # Only top-level children
            mod_total = sum(p.numel() for p in module.parameters())
            mod_trainable = sum(
                p.numel() for p in module.parameters() if p.requires_grad
            )
            total_params += mod_total
            trainable_params += mod_trainable
            # CONSISTENCY: thousands separators on both counts (Total previously
            # had width-only formatting while Trainable used grouping).
            self.logger.info(
                f"{name:<25} | Total: {mod_total:<20,} | Trainable: {mod_trainable:,}"
            )
        self.logger.info("-" * 80)
        self.logger.info(f"Total Parameters: {total_params:,}")
        self.logger.info(f"Trainable Parameters: {trainable_params:,}")

    def forward(self, x):
        """Map (B, in_ch, H, W) -> (B, out_ch, H, W)."""
        B = x.size(0)
        # Patch embedding
        x = self.patch_embed(x)  # (B, embed_dim, H//p, W//p)
        H_p, W_p = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)  # (B, N, embed_dim)
        # Add learned positional embedding (sliced in case N < num_patches)
        x = x + self.pos_embed[:, : x.size(1), :]
        x = self.pos_dropout(x)
        # Transformer
        x = self.encoder(x)  # (B, N, embed_dim)
        # Reshape tokens back to the 2D grid
        x = x.transpose(1, 2).reshape(
            B, self.embed_dim, H_p, W_p
        )  # (B, embed_dim, H//p, W//p)
        # Decode to full-res output
        out = self.decoder(x)  # (B, out_ch, H, W)
        return out
# ============================================================================================================
# TESTING
# ============================================================================================================
def test_model(model, name, input_tensor):
try:
print(f"Testing {name}...")
out = model(input_tensor)
print(f"{name} output shape: {out.shape}\n")
except Exception as e:
print(f"{name} failed with error: {e}\n")
def main():
    """Smoke-test both patch-to-patch ViT variants on a dummy 512x512 input."""
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    print("Starting model tests...\n")
    # Set image parameters
    B, C, H, W = 1, 3, 512, 512
    dummy_input = torch.randn(B, C, H, W)
    # 1. ViT Patch2Patch (version 1)
    model1 = ViT_Patch2Patch(img_size=512, patch_size=8, in_ch=3, out_ch=3)
    test_model(model1, "ViT_Patch2Patch (version 1)", dummy_input)
    print("==" * 50)
    # 2. ViT Patch2Patch (version 2)
    # BUGFIX: this previously rebound `model1`; use a distinct name so both
    # models remain referable and the intent is clear.
    model2 = ViT_Patch2Patch_ver2(img_size=512, patch_size=8, in_ch=3, out_ch=3)
    test_model(model2, "ViT_Patch2Patch (version 2)", dummy_input)
    print("==" * 50)


if __name__ == "__main__":
    main()