# Hugging Face Spaces demo app (Gradio) — status header from the Spaces page removed.
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
| import numpy as np | |
| from skimage.metrics import structural_similarity as ssim | |
| import requests | |
| from io import BytesIO | |
| # ========================================================== | |
| # Model Architecture (copied from training script) | |
| # ========================================================== | |
class TernarySTE(torch.autograd.Function):
    """Ternary quantizer {-1, 0, +1} with a straight-through-style surrogate gradient.

    Forward maps x -> sign(x) when |x| > 1 and 0 otherwise. Backward replaces
    the almost-everywhere-zero true derivative with a sum of sigmoid-derivative
    bumps centered on the two thresholds (+/-1), sharpened by `temp`.
    """

    # Fix: torch.autograd.Function requires forward/backward to be
    # @staticmethod; the original relied on unbound-method calling behavior.
    @staticmethod
    def forward(ctx, x, temp):
        # Save inputs for the surrogate-gradient computation in backward().
        ctx.save_for_backward(x, temp)
        # Ternarize with a dead zone on [-1, 1].
        return torch.sign(x) * (x.abs() > 1).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        x, temp = ctx.saved_tensors

        def sigmoid_derivative(z):
            s = torch.sigmoid(z)
            return s * (1.0 - s)

        # Surrogate gradient: sigmoid' bumps at the +/-1 quantization
        # thresholds, scaled by 1/temp (smaller temp -> sharper bumps).
        surrogate_grad = (sigmoid_derivative((x - 1.0) / temp) +
                          sigmoid_derivative((x + 1.0) / temp)) / temp
        grad_x = grad_output * surrogate_grad
        # No gradient for `temp` (it is an annealed schedule value, not learned).
        return grad_x, None
class AdaptiveBitwiseSign(nn.Module):
    """Ternary quantization layer whose surrogate-gradient temperature can be annealed."""

    def __init__(self, initial_temp=1.0):
        super().__init__()
        # A buffer (not a Parameter): follows .to(device) and checkpointing,
        # but receives no gradient updates.
        self.register_buffer('temp', torch.tensor(initial_temp, dtype=torch.float32))

    def forward(self, x):
        return TernarySTE.apply(x, self.temp)

    def anneal_temp(self, factor=0.98):
        # Geometric decay with a hard floor at 0.99. NOTE(review): the floor is
        # essentially the default initial value, so annealing saturates almost
        # immediately — presumably intentional; confirm against training code.
        self.temp.mul_(factor)
        self.temp.clamp_(min=0.99)
class InceptionDWConv2d(nn.Module):
    """Inception-style depthwise token mixer.

    Splits channels into a square-kernel branch, a horizontal-band branch, a
    vertical-band branch, and an untouched identity remainder, then
    re-concatenates along the channel dim. Spatial size is preserved
    (reflect padding everywhere).
    """

    def __init__(self, in_channels, square_kernel_size=3, band_kernel_size=11, branch_ratio=0.125):
        super().__init__()
        gc = int(in_channels * branch_ratio)  # channels given to each conv branch
        half_sq = square_kernel_size // 2
        half_band = band_kernel_size // 2

        def depthwise(kernel, pad):
            # Depthwise (groups == channels) conv that keeps H and W unchanged.
            return nn.Conv2d(gc, gc, kernel, padding=pad, groups=gc, padding_mode='reflect')

        self.dwconv_hw = depthwise(square_kernel_size, half_sq)
        self.dwconv_w = depthwise((1, band_kernel_size), (0, half_band))
        self.dwconv_h = depthwise((band_kernel_size, 1), (half_band, 0))
        # (square, horizontal-band, vertical-band, identity) split sizes.
        self.split_indexes = (gc, gc, gc, in_channels - 3 * gc)

    def forward(self, x):
        sq, band_w, band_h, ident = torch.split(x, self.split_indexes, dim=1)
        branches = [
            self.dwconv_hw(sq),
            self.dwconv_w(band_w),
            self.dwconv_h(band_h),
            ident,
        ]
        return torch.cat(branches, dim=1)
class InceptionNeXtBlock(nn.Module):
    """Residual InceptionNeXt block: depthwise token mixing, BatchNorm, then a
    1x1-conv MLP with GELU; the result is added back to the input."""

    def __init__(self, dim, expansion_ratio=4):
        super().__init__()
        self.token_mixer = InceptionDWConv2d(dim)
        self.norm = nn.BatchNorm2d(dim)
        hidden = int(dim * expansion_ratio)
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, hidden, 1),
            nn.GELU(),
            nn.Conv2d(hidden, dim, 1),
        )

    def forward(self, x):
        # mix -> norm -> MLP, then residual add.
        return self.mlp(self.norm(self.token_mixer(x))) + x
class TransformerBlock(nn.Module):
    """Pre-norm ViT-style block operating on a 2D feature map.

    The (B, C, H, W) input is flattened to (B, H*W, C) tokens, passed through
    self-attention and an MLP (each with a residual connection), then reshaped
    back to (B, C, H, W).
    """

    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        batch, channels, height, width = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        normed = self.norm1(tokens)
        attn_out, _ = self.attn(normed, normed, normed)
        tokens = tokens + attn_out
        tokens = tokens + self.mlp(self.norm2(tokens))
        return tokens.transpose(1, 2).reshape(batch, channels, height, width)
class MultiScaleEncoder(nn.Module):
    """Hierarchical encoder producing a ternary ('discrete') latent and a
    continuous latent.

    Pipeline: PixelUnshuffle(patch) space-to-depth -> 7x7 stem conv ->
    InceptionNeXt blocks -> len(dims)-1 downsampling stages
    (PixelUnshuffle(2) + 1x1 conv), with 8 TransformerBlocks appended to the
    deepest stage. Two 3x3 heads map the final feature map to the latents;
    the discrete head's output is ternarized.

    NOTE(review): attribute names must stay as-is — they are the state_dict
    keys of the pretrained checkpoint loaded in load_model().
    """
    def __init__(self, patch=4, dims=[128, 256, 512, 1024], depths=[4, 4, 4, 4], latent_discrete=128, latent_continuous=64):
        super().__init__()
        # Space-to-depth: (B, 3, H, W) -> (B, 3*patch^2, H/patch, W/patch).
        self.unshuffle = nn.PixelUnshuffle(patch)
        self.stem = nn.Conv2d(3 * patch * patch, dims[0], 7, padding=3, padding_mode='reflect')
        self.initial_blocks = nn.Sequential(
            *[InceptionNeXtBlock(dims[0]) for _ in range(depths[0])]
        )
        self.stages = nn.ModuleList()
        for i in range(len(dims)-1):
            # 2x spatial downsample; the 1x1 conv maps the 4x channel blow-up
            # from PixelUnshuffle onto dims[i+1].
            downsample = nn.Sequential(
                nn.PixelUnshuffle(2),
                nn.Conv2d(dims[i] * 4, dims[i+1], 1)
            )
            if i == len(dims) - 2:
                # Deepest stage only: conv blocks followed by self-attention.
                blocks = nn.Sequential(
                    *[InceptionNeXtBlock(dims[i+1]) for _ in range(depths[i+1])],
                    *[TransformerBlock(dims[i+1], num_heads=16) for _ in range(8)]
                )
            else:
                blocks = nn.Sequential(
                    *[InceptionNeXtBlock(dims[i+1]) for _ in range(depths[i+1])]
                )
            self.stages.append(nn.ModuleList([downsample, blocks]))
        self.to_latent_discrete = nn.Conv2d(dims[-1], latent_discrete, 3, padding=1, padding_mode='reflect')
        self.to_latent_continuous = nn.Conv2d(dims[-1], latent_continuous, 3, padding=1, padding_mode='reflect')
        # Alias — presumably kept for compatibility with older checkpoints or
        # callers; unused in this file. TODO confirm before removing.
        self.to_latent = self.to_latent_discrete
        self.quant = AdaptiveBitwiseSign()
    def forward(self, x):
        x = self.unshuffle(x)
        x = self.stem(x)
        x = self.initial_blocks(x)
        for downsample, blocks in self.stages:
            x = downsample(x)
            x = blocks(x)
        # Discrete path is ternarized to {-1, 0, +1}; continuous path is raw.
        z_discrete = self.quant(self.to_latent_discrete(x))
        z_continuous = self.to_latent_continuous(x)
        return z_discrete, z_continuous
class MultiScaleDecoder(nn.Module):
    """Hierarchical decoder mirroring MultiScaleEncoder.

    Concatenates the discrete and continuous latents, maps them to the deepest
    feature width, runs conv + attention blocks, then upsamples through the
    stages in reverse (1x1 conv + PixelShuffle(2)) and finally projects to
    pixels with PixelShuffle(patch).

    NOTE(review): attribute names must stay as-is — they are the state_dict
    keys of the pretrained checkpoint loaded in load_model().
    """
    def __init__(self, patch=4, dims=[128, 256, 512, 1024], depths=[4, 4, 4, 4], latent_discrete=128, latent_continuous=64):
        super().__init__()
        # Decoder consumes both latent halves concatenated on channels.
        total_latent = latent_discrete + latent_continuous
        self.from_latent = nn.Conv2d(total_latent, dims[-1], 1)
        self.initial_blocks = nn.Sequential(
            *[InceptionNeXtBlock(dims[-1]) for _ in range(depths[-1])],
            *[TransformerBlock(dims[-1], num_heads=16) for _ in range(8)]
        )
        self.stages = nn.ModuleList()
        for i in range(len(dims)-1, 0, -1):
            # 1x1 conv expands channels 4x, PixelShuffle trades them for 2x
            # spatial resolution.
            upsample = nn.Sequential(
                nn.Conv2d(dims[i], dims[i-1] * 4, 1),
                nn.PixelShuffle(2)
            )
            blocks = nn.Sequential(
                *[InceptionNeXtBlock(dims[i-1]) for _ in range(depths[i-1])]
            )
            self.stages.append(nn.ModuleList([upsample, blocks]))
        self.to_pixels = nn.Conv2d(dims[0], 3 * patch * patch, 3, padding=1, padding_mode='reflect')
        self.shuffle = nn.PixelShuffle(patch)
    def forward(self, z_discrete, z_continuous, return_feat=False):
        z = torch.cat([z_discrete, z_continuous], dim=1)
        x = self.from_latent(z)
        x = self.initial_blocks(x)
        # Deepest post-attention feature map, optionally returned for callers
        # that need an intermediate representation (e.g. a feature loss).
        feat = x
        for upsample, blocks in self.stages:
            x = upsample(x)
            x = blocks(x)
        img = self.shuffle(self.to_pixels(x))
        if return_feat:
            return img, feat
        return img
class BinaryAutoencoder(nn.Module):
    """Autoencoder pairing MultiScaleEncoder with MultiScaleDecoder.

    The latent is a ternary ('discrete') part plus a continuous part; decoded
    images are clamped to [-1, 1].
    """

    def __init__(self, latent_discrete=128, latent_continuous=64):
        super().__init__()
        self.encoder = MultiScaleEncoder(latent_discrete=latent_discrete, latent_continuous=latent_continuous)
        self.decoder = MultiScaleDecoder(latent_discrete=latent_discrete, latent_continuous=latent_continuous)
        # Set to None and never used in this file — presumably a training-time
        # auxiliary head slot; kept so checkpoints load cleanly.
        self.dino_head = None
        self.latent_discrete = latent_discrete
        self.latent_continuous = latent_continuous

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z_discrete, z_continuous, return_feat=False):
        # Guard clause for the common no-feature path.
        if not return_feat:
            return torch.clamp(self.decoder(z_discrete, z_continuous), -1, 1)
        recon, feat = self.decoder(z_discrete, z_continuous, return_feat=True)
        return torch.clamp(recon, -1, 1), feat

    def forward(self, x):
        z_discrete, z_continuous = self.encode(x)
        return self.decode(z_discrete, z_continuous), z_discrete, z_continuous
| # ========================================================== | |
| # Metrics | |
| # ========================================================== | |
def compute_psnr(pred, target):
    """Peak signal-to-noise ratio (dB) between two image tensors in [-1, 1]."""
    # Map both tensors to [0, 1] so the peak signal value is 1.0.
    pred01 = (pred + 1) / 2
    target01 = (target + 1) / 2
    mse = F.mse_loss(pred01, target01)
    # Epsilon guards against log10 of infinity on a perfect reconstruction.
    return (10 * torch.log10(1.0 / (mse + 1e-8))).item()
def compute_ssim(pred, target):
    """Structural similarity between two (1, C, H, W) image tensors in [-1, 1].

    Returns the scikit-image SSIM over the RGB image (channel axis last).
    """
    # Map to [0, 1], drop the batch dim, and convert to HWC numpy for skimage.
    pred = ((pred + 1) / 2).squeeze(0).permute(1, 2, 0).cpu().numpy()
    target = ((target + 1) / 2).squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Fix: `channel_axis` is the replacement for the `multichannel` keyword,
    # which was deprecated in scikit-image 0.19 and removed in later releases;
    # passing both raises a TypeError on current versions.
    return ssim(target, pred, channel_axis=2, data_range=1.0)
| # ========================================================== | |
| # Preprocessing and Demo Functions | |
| # ========================================================== | |
def resize_and_crop(image, min_size=512, multiple=32):
    """Resize so the shortest side equals `min_size`, then center-crop so both
    dimensions are exact multiples of `multiple`."""
    w, h = image.size
    # Scale so the shortest side becomes exactly min_size; the other side
    # scales proportionally (truncated to int, matching PIL conventions).
    if w < h:
        new_w, new_h = min_size, int(h * (min_size / w))
    else:
        new_w, new_h = int(w * (min_size / h)), min_size
    image = image.resize((new_w, new_h), Image.Resampling.BILINEAR)
    # Largest multiple-of-`multiple` box that fits inside the resized image
    # (32 matches the encoder's total downsampling factor).
    final_w = new_w - new_w % multiple
    final_h = new_h - new_h % multiple
    # Center the crop box.
    left = (new_w - final_w) // 2
    top = (new_h - final_h) // 2
    return image.crop((left, top, left + final_w, top + final_h))
def load_model(device='cpu'):
    """Download the pretrained BinaryAutoencoder checkpoint and return it in eval mode.

    Args:
        device: torch device (string or torch.device) the model is mapped to.
    Raises:
        requests.HTTPError: if the checkpoint download returns an error status.
    """
    print("Loading model...")
    # Latent sizes must match the checkpoint: 256 ternary + 32 continuous channels.
    model = BinaryAutoencoder(latent_discrete=256, latent_continuous=32).to(device)
    url = "https://huggingface.co/Shio-Koube/vae_binary/resolve/main/latest.pt"
    # Fix: bound the request and fail fast on HTTP errors instead of hanging
    # forever or handing torch.load an HTML error page.
    response = requests.get(url, timeout=120)
    response.raise_for_status()
    # SECURITY NOTE: torch.load unpickles arbitrary objects; acceptable only
    # because the URL is a fixed, trusted repository.
    checkpoint = torch.load(BytesIO(response.content), map_location=device)
    # strict=False tolerates checkpoint keys absent from this inference-only
    # model (presumably training-time heads — confirm against training code).
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()
    print("Model loaded successfully!")
    return model
def reconstruct_with_channels(model, img_tensor, num_channels, device='cpu'):
    """Encode `img_tensor`, then decode using only the first `num_channels`
    continuous-latent channels (the remainder are zeroed).

    The discrete latent is always used in full. `device` is currently unused
    but kept for interface compatibility with existing callers.
    """
    with torch.no_grad():
        z_discrete, z_continuous = model.encode(img_tensor)
        if num_channels >= z_continuous.shape[1]:
            kept = z_continuous
        else:
            # Zero the tail channels rather than slicing so the decoder still
            # receives its full expected channel count.
            kept = z_continuous.clone()
            kept[:, num_channels:, :, :] = 0
        return model.decode(z_discrete, kept)
def process_image(image):
    """Run the full demo pipeline on one uploaded image.

    Preprocesses the image, reconstructs it with 0/8/16/32 active continuous
    channels, and returns (preprocessed original PIL image, 2x2 grid of
    reconstructions, list of [channels, PSNR, SSIM] metric rows).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Lazy-load the model exactly once, cached on the function object.
    if not hasattr(process_image, 'model'):
        process_image.model = load_model(device)
    model = process_image.model

    # Shortest side -> 512, cropped to a multiple of 32 for the encoder.
    processed_image = resize_and_crop(image, min_size=512, multiple=32)
    to_model = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    img_tensor = to_model(processed_image).unsqueeze(0).to(device)

    def tensor_to_pil(t):
        # [-1, 1] CHW tensor -> uint8 HWC PIL image.
        arr = ((t + 1) / 2).permute(1, 2, 0).cpu().numpy()
        return Image.fromarray((arr * 255).astype(np.uint8))

    original_pil = tensor_to_pil(img_tensor[0])

    results = []
    metrics_data = []
    # Four channel counts -> four cells of the 2x2 grid.
    for num_ch in [0, 8, 16, 32]:
        recon = reconstruct_with_channels(model, img_tensor, num_ch, device)
        results.append(tensor_to_pil(recon[0]))
        psnr = compute_psnr(recon, img_tensor)
        ssim_val = compute_ssim(recon, img_tensor)
        metrics_data.append([f'{num_ch}ch', f'{psnr:.2f}', f'{ssim_val:.4f}'])

    # Paste the reconstructions into a 2x2 grid in row-major order.
    cell_width, cell_height = results[0].size
    grid_img = Image.new('RGB', (cell_width * 2, cell_height * 2))
    for idx, img in enumerate(results[:4]):
        row, col = divmod(idx, 2)
        grid_img.paste(img, (col * cell_width, row * cell_height))
    return original_pil, grid_img, metrics_data
| # ========================================================== | |
| # Gradio Interface | |
| # ========================================================== | |
# Build the Gradio UI: an input column (upload + button) and an output column,
# followed by the reconstruction grid and metrics table.
with gr.Blocks(title="Binary VAE with Continuous Channels Demo") as demo:
    gr.Markdown("""
    # Binary VAE with Continuous Channels
    This demo shows how the reconstruction quality improves as we use more continuous latent channels.
    The model uses 256 discrete (ternary) channels + 0-32 continuous channels. Input images are resized so their shortest edge is 512, and then cropped so dimensions are perfectly divisible by 32.
    **Channel counts tested:** 0, 8, 16, 32
    """)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            submit_btn = gr.Button("Generate Reconstructions", variant="primary")
        with gr.Column():
            original_output = gr.Image(label="Original (preprocessed)")
            gr.Markdown("## Reconstructions (2x2 Grid)")
            # Cell order matches the [0, 8, 16, 32] loop in process_image.
            gr.Markdown("**Top row:** 0ch, 8ch | **Bottom row:** 16ch, 32ch")
            grid_output = gr.Image(label="Reconstructions Grid")
            gr.Markdown("## Metrics")
            metrics_table = gr.Dataframe(
                headers=['Channels', 'PSNR (dB)', 'SSIM'],
                label="Quality Metrics"
            )
    # Wire the button to the pipeline: process_image returns
    # (original, grid, metrics rows) matching the three outputs below.
    submit_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[original_output, grid_output, metrics_table]
    )
    gr.Markdown("""
    ### Notes:
    - **PSNR**: Peak Signal-to-Noise Ratio (higher is better, >30 dB is good)
    - **SSIM**: Structural Similarity Index (0-1, higher is better)
    - The model has 32 continuous channels max.
    - Discrete latent (256 ternary channels) is always active.
    """)
# Script entry point (Hugging Face Spaces runs this module directly).
if __name__ == "__main__":
    demo.launch()