import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
from skimage.metrics import structural_similarity as ssim
import requests
from io import BytesIO

# ==========================================================
# Model Architecture (copied from training script)
# ==========================================================


class TernarySTE(torch.autograd.Function):
    """Ternary quantizer with a sigmoid-surrogate straight-through estimator.

    Forward maps activations to {-1, 0, +1}: sign(x) wherever |x| > 1, else 0.
    Backward replaces the (zero a.e.) true gradient with the sum of two
    sigmoid-derivative bumps centered on the step transitions at x = -1 and
    x = +1; smaller ``temp`` makes the surrogate sharper.
    """

    @staticmethod
    def forward(ctx, x, temp):
        ctx.save_for_backward(x, temp)
        # sign(x) masked to zero inside [-1, 1] -> values in {-1, 0, +1}
        return torch.sign(x) * (x.abs() > 1).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        x, temp = ctx.saved_tensors

        def sigmoid_derivative(z):
            s = torch.sigmoid(z)
            return s * (1.0 - s)

        # One bump per transition point; the 1/temp factor keeps the bump's
        # integral roughly constant as the temperature anneals.
        surrogate_grad = (sigmoid_derivative((x - 1.0) / temp)
                          + sigmoid_derivative((x + 1.0) / temp)) / temp
        grad_x = grad_output * surrogate_grad
        # No gradient w.r.t. the temperature (it is a buffer, not a parameter).
        return grad_x, None


class AdaptiveBitwiseSign(nn.Module):
    """Ternary sign activation whose surrogate-gradient temperature can be annealed.

    The temperature is registered as a buffer so it is saved in checkpoints
    and moved with ``.to(device)`` but is not optimized.
    """

    def __init__(self, initial_temp=1.0):
        super().__init__()
        self.register_buffer('temp', torch.tensor(initial_temp, dtype=torch.float32))

    def forward(self, x):
        return TernarySTE.apply(x, self.temp)

    def anneal_temp(self, factor=0.98):
        # Multiplicative decay with a floor.
        # NOTE(review): with initial_temp=1.0 and min=0.99 the temperature
        # bottoms out after a single step — confirm the floor is intentional.
        self.temp.mul_(factor).clamp_(min=0.99)


class InceptionDWConv2d(nn.Module):
    """Inception-style depthwise token mixer (InceptionNeXt).

    Splits channels into four groups: a square depthwise conv, a horizontal
    band conv, a vertical band conv, and an identity branch, then
    re-concatenates along the channel dimension.
    """

    def __init__(self, in_channels, square_kernel_size=3, band_kernel_size=11,
                 branch_ratio=0.125):
        super().__init__()
        gc = int(in_channels * branch_ratio)  # channels per conv branch
        self.dwconv_hw = nn.Conv2d(
            gc, gc, square_kernel_size,
            padding=square_kernel_size // 2,
            groups=gc, padding_mode='reflect'
        )
        self.dwconv_w = nn.Conv2d(
            gc, gc, kernel_size=(1, band_kernel_size),
            padding=(0, band_kernel_size // 2),
            groups=gc, padding_mode='reflect'
        )
        self.dwconv_h = nn.Conv2d(
            gc, gc, kernel_size=(band_kernel_size, 1),
            padding=(band_kernel_size // 2, 0),
            groups=gc, padding_mode='reflect'
        )
        # Remaining channels pass through untouched.
        self.split_indexes = (gc, gc, gc, in_channels - 3 * gc)

    def forward(self, x):
        x_hw, x_w, x_h, x_id = torch.split(x, self.split_indexes, dim=1)
        return torch.cat(
            (self.dwconv_hw(x_hw), self.dwconv_w(x_w), self.dwconv_h(x_h), x_id),
            dim=1
        )


class InceptionNeXtBlock(nn.Module):
    """Residual block: Inception depthwise mixer -> BatchNorm -> pointwise MLP."""

    def __init__(self, dim, expansion_ratio=4):
        super().__init__()
        self.token_mixer = InceptionDWConv2d(dim)
        self.norm = nn.BatchNorm2d(dim)
        hidden_dim = int(dim * expansion_ratio)
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1)
        )

    def forward(self, x):
        shortcut = x
        x = self.token_mixer(x)
        x = self.norm(x)
        x = self.mlp(x)
        return x + shortcut


class TransformerBlock(nn.Module):
    """Pre-norm self-attention block operating on (B, C, H, W) feature maps.

    Spatial positions are flattened into a sequence of length H*W for
    attention, then reshaped back to the original map.
    """

    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                          batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x_flat = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        x_norm = self.norm1(x_flat)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x_flat = x_flat + attn_out
        x_flat = x_flat + self.mlp(self.norm2(x_flat))
        x = x_flat.transpose(1, 2).reshape(B, C, H, W)
        return x


class MultiScaleEncoder(nn.Module):
    """Hierarchical encoder producing a ternary (discrete) and a continuous latent.

    Total downsampling is patch * 2**(len(dims) - 1) = 32 with defaults.
    The deepest stage appends 8 transformer blocks after its conv blocks.
    """

    def __init__(self, patch=4, dims=(128, 256, 512, 1024), depths=(4, 4, 4, 4),
                 latent_discrete=128, latent_continuous=64):
        super().__init__()
        self.unshuffle = nn.PixelUnshuffle(patch)
        self.stem = nn.Conv2d(3 * patch * patch, dims[0], 7, padding=3,
                              padding_mode='reflect')
        self.initial_blocks = nn.Sequential(
            *[InceptionNeXtBlock(dims[0]) for _ in range(depths[0])]
        )
        self.stages = nn.ModuleList()
        for i in range(len(dims) - 1):
            downsample = nn.Sequential(
                nn.PixelUnshuffle(2),
                nn.Conv2d(dims[i] * 4, dims[i + 1], 1)
            )
            if i == len(dims) - 2:
                # Deepest stage: conv blocks followed by global attention.
                blocks = nn.Sequential(
                    *[InceptionNeXtBlock(dims[i + 1]) for _ in range(depths[i + 1])],
                    *[TransformerBlock(dims[i + 1], num_heads=16) for _ in range(8)]
                )
            else:
                blocks = nn.Sequential(
                    *[InceptionNeXtBlock(dims[i + 1]) for _ in range(depths[i + 1])]
                )
            self.stages.append(nn.ModuleList([downsample, blocks]))
        self.to_latent_discrete = nn.Conv2d(dims[-1], latent_discrete, 3,
                                            padding=1, padding_mode='reflect')
        self.to_latent_continuous = nn.Conv2d(dims[-1], latent_continuous, 3,
                                              padding=1, padding_mode='reflect')
        # Alias — presumably kept so older checkpoint keys still resolve.
        self.to_latent = self.to_latent_discrete
        self.quant = AdaptiveBitwiseSign()

    def forward(self, x):
        x = self.unshuffle(x)
        x = self.stem(x)
        x = self.initial_blocks(x)
        for downsample, blocks in self.stages:
            x = downsample(x)
            x = blocks(x)
        z_discrete = self.quant(self.to_latent_discrete(x))
        z_continuous = self.to_latent_continuous(x)
        return z_discrete, z_continuous


class MultiScaleDecoder(nn.Module):
    """Hierarchical decoder mirroring ``MultiScaleEncoder``.

    Consumes the concatenated discrete + continuous latent and upsamples
    back to pixel space via PixelShuffle stages.
    """

    def __init__(self, patch=4, dims=(128, 256, 512, 1024), depths=(4, 4, 4, 4),
                 latent_discrete=128, latent_continuous=64):
        super().__init__()
        total_latent = latent_discrete + latent_continuous
        self.from_latent = nn.Conv2d(total_latent, dims[-1], 1)
        self.initial_blocks = nn.Sequential(
            *[InceptionNeXtBlock(dims[-1]) for _ in range(depths[-1])],
            *[TransformerBlock(dims[-1], num_heads=16) for _ in range(8)]
        )
        self.stages = nn.ModuleList()
        for i in range(len(dims) - 1, 0, -1):
            upsample = nn.Sequential(
                nn.Conv2d(dims[i], dims[i - 1] * 4, 1),
                nn.PixelShuffle(2)
            )
            blocks = nn.Sequential(
                *[InceptionNeXtBlock(dims[i - 1]) for _ in range(depths[i - 1])]
            )
            self.stages.append(nn.ModuleList([upsample, blocks]))
        self.to_pixels = nn.Conv2d(dims[0], 3 * patch * patch, 3,
                                   padding=1, padding_mode='reflect')
        self.shuffle = nn.PixelShuffle(patch)

    def forward(self, z_discrete, z_continuous, return_feat=False):
        z = torch.cat([z_discrete, z_continuous], dim=1)
        x = self.from_latent(z)
        x = self.initial_blocks(x)
        feat = x  # deepest decoder feature, optionally returned for aux losses
        for upsample, blocks in self.stages:
            x = upsample(x)
            x = blocks(x)
        img = self.shuffle(self.to_pixels(x))
        if return_feat:
            return img, feat
        return img


class BinaryAutoencoder(nn.Module):
    """Autoencoder with a ternary discrete latent plus a continuous latent.

    Reconstructions are clamped to [-1, 1] to match the input normalization.
    """

    def __init__(self, latent_discrete=128, latent_continuous=64):
        super().__init__()
        self.encoder = MultiScaleEncoder(latent_discrete=latent_discrete,
                                         latent_continuous=latent_continuous)
        self.decoder = MultiScaleDecoder(latent_discrete=latent_discrete,
                                         latent_continuous=latent_continuous)
        self.dino_head = None  # placeholder, unused at inference
        self.latent_discrete = latent_discrete
        self.latent_continuous = latent_continuous

    def encode(self, x):
        """Return (z_discrete, z_continuous) for an input batch in [-1, 1]."""
        return self.encoder(x)

    def decode(self, z_discrete, z_continuous, return_feat=False):
        """Decode latents to an image in [-1, 1]; optionally also return
        the deepest decoder feature map."""
        if return_feat:
            recon, feat = self.decoder(z_discrete, z_continuous, return_feat=True)
            return torch.clamp(recon, -1, 1), feat
        recon = self.decoder(z_discrete, z_continuous)
        return torch.clamp(recon, -1, 1)

    def forward(self, x):
        z_discrete, z_continuous = self.encode(x)
        recon = self.decode(z_discrete, z_continuous)
        return recon, z_discrete, z_continuous


# ==========================================================
# Metrics
# ==========================================================

def compute_psnr(pred, target):
    """Compute PSNR (dB) for image tensors in [-1, 1] range."""
    pred = (pred + 1) / 2
    target = (target + 1) / 2
    mse = F.mse_loss(pred, target)
    # Epsilon guards against log10(inf) on a perfect reconstruction.
    psnr = 10 * torch.log10(1.0 / (mse + 1e-8))
    return psnr.item()


def compute_ssim(pred, target):
    """Compute SSIM for single-image (1, C, H, W) tensors in [-1, 1] range."""
    # Convert to [0, 1] HWC numpy arrays on CPU.
    pred = ((pred + 1) / 2).squeeze(0).permute(1, 2, 0).cpu().numpy()
    target = ((target + 1) / 2).squeeze(0).permute(1, 2, 0).cpu().numpy()
    # FIX: the deprecated `multichannel=True` kwarg was removed in
    # scikit-image >= 0.23 and raised TypeError; `channel_axis` replaces it.
    return ssim(target, pred, channel_axis=2, data_range=1.0)


# ==========================================================
# Preprocessing and Demo Functions
# ==========================================================

def resize_and_crop(image, min_size=512, multiple=32):
    """Resizes shortest side to min_size, then crops so edges are divisible by `multiple`."""
    w, h = image.size
    # Scale shortest side to min_size, preserving aspect ratio.
    if w < h:
        new_w = min_size
        new_h = int(h * (min_size / w))
    else:
        new_h = min_size
        new_w = int(w * (min_size / h))
    image = image.resize((new_w, new_h), Image.Resampling.BILINEAR)
    # Make divisible by target multiple (32 for VAE downsampling).
    final_w = (new_w // multiple) * multiple
    final_h = (new_h // multiple) * multiple
    # Center crop.
    left = (new_w - final_w) // 2
    top = (new_h - final_h) // 2
    right = left + final_w
    bottom = top + final_h
    return image.crop((left, top, right, bottom))


def load_model(device='cpu'):
    """Load the pretrained model from HuggingFace."""
    print("Loading model...")
    # Updated to latent_continuous=32 and latent_discrete=256
    model = BinaryAutoencoder(latent_discrete=256, latent_continuous=32).to(device)
    # Download checkpoint.
    url = "https://huggingface.co/Shio-Koube/vae_binary/resolve/main/latest.pt"
    response = requests.get(url)
    # FIX: fail fast on HTTP errors instead of handing an error page
    # to torch.load, which would die with an opaque deserialization error.
    response.raise_for_status()
    # NOTE(review): torch.load on a downloaded file deserializes pickled
    # data — acceptable here only because the source is a trusted repo.
    checkpoint = torch.load(BytesIO(response.content), map_location=device)
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()
    print("Model loaded successfully!")
    return model


def reconstruct_with_channels(model, img_tensor, num_channels, device='cpu'):
    """Reconstruct image using only the first num_channels of continuous latent.

    ``device`` is unused (the tensors already live on the right device) but
    kept for interface compatibility with existing callers.
    """
    with torch.no_grad():
        # Encode
        z_discrete, z_continuous = model.encode(img_tensor)
        # Zero out channels beyond num_channels
        if num_channels < z_continuous.shape[1]:
            z_continuous_masked = z_continuous.clone()
            z_continuous_masked[:, num_channels:, :, :] = 0
        else:
            z_continuous_masked = z_continuous
        # Decode
        recon = model.decode(z_discrete, z_continuous_masked)
    return recon


def process_image(image):
    """Process uploaded image and generate reconstructions with different channel counts.

    Returns (original PIL image, 2x2 reconstruction grid, metrics rows).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model once and cache it on the function object.
    # NOTE(review): the cache ignores later device changes — fine for a
    # single-process demo.
    if not hasattr(process_image, 'model'):
        process_image.model = load_model(device)
    model = process_image.model

    # Preprocess image: Resize shortest side to 512, ensure divisible by 32
    processed_image = resize_and_crop(image, min_size=512, multiple=32)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    img_tensor = transform(processed_image).unsqueeze(0).to(device)

    # Channel counts to test: 4 images total for a 2x2 grid
    all_channels = [0, 8, 16, 32]
    results = []
    metrics_data = []

    # Original image for comparison
    original_np = ((img_tensor[0] + 1) / 2).permute(1, 2, 0).cpu().numpy()
    original_pil = Image.fromarray((original_np * 255).astype(np.uint8))

    for num_ch in all_channels:
        # Reconstruct
        recon = reconstruct_with_channels(model, img_tensor, num_ch, device)
        # Convert to PIL
        recon_np = ((recon[0] + 1) / 2).permute(1, 2, 0).cpu().numpy()
        recon_pil = Image.fromarray((recon_np * 255).astype(np.uint8))
        # Compute metrics
        psnr = compute_psnr(recon, img_tensor)
        ssim_val = compute_ssim(recon, img_tensor)
        results.append(recon_pil)
        metrics_data.append([f'{num_ch}ch', f'{psnr:.2f}', f'{ssim_val:.4f}'])

    # Create 2x2 grid dynamically based on the final image dimensions
    cell_width, cell_height = results[0].size
    grid_width = 2
    grid_height = 2
    grid_img = Image.new('RGB', (cell_width * grid_width, cell_height * grid_height))
    for idx, img in enumerate(results[:4]):
        row = idx // grid_width
        col = idx % grid_width
        grid_img.paste(img, (col * cell_width, row * cell_height))

    return original_pil, grid_img, metrics_data


# ==========================================================
# Gradio Interface
# ==========================================================

with gr.Blocks(title="Binary VAE with Continuous Channels Demo") as demo:
    gr.Markdown("""
    # Binary VAE with Continuous Channels

    This demo shows how the reconstruction quality improves as we use more continuous latent channels.
    The model uses 256 discrete (ternary) channels + 0-32 continuous channels.

    Input images are resized so their shortest edge is 512, and then cropped so dimensions are perfectly divisible by 32.

    **Channel counts tested:** 0, 8, 16, 32
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            submit_btn = gr.Button("Generate Reconstructions", variant="primary")
        with gr.Column():
            original_output = gr.Image(label="Original (preprocessed)")

    gr.Markdown("## Reconstructions (2x2 Grid)")
    gr.Markdown("**Top row:** 0ch, 8ch | **Bottom row:** 16ch, 32ch")
    grid_output = gr.Image(label="Reconstructions Grid")

    gr.Markdown("## Metrics")
    metrics_table = gr.Dataframe(
        headers=['Channels', 'PSNR (dB)', 'SSIM'],
        label="Quality Metrics"
    )

    submit_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[original_output, grid_output, metrics_table]
    )

    gr.Markdown("""
    ### Notes:
    - **PSNR**: Peak Signal-to-Noise Ratio (higher is better, >30 dB is good)
    - **SSIM**: Structural Similarity Index (0-1, higher is better)
    - The model has 32 continuous channels max.
    - Discrete latent (256 ternary channels) is always active.
    """)


if __name__ == "__main__":
    demo.launch()