# vae_demo / app.py
# Author: Shio-Koube
# Commit: "Update app.py" (e060842, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
from skimage.metrics import structural_similarity as ssim
import requests
from io import BytesIO
# ==========================================================
# Model Architecture (copied from training script)
# ==========================================================
class TernarySTE(torch.autograd.Function):
    """Ternary quantizer with a smooth surrogate gradient.

    Forward maps every element of ``x`` to -1, 0 or +1: the sign of ``x``
    where ``|x| > 1`` and 0 elsewhere. Backward substitutes a surrogate
    built from two sigmoid-derivative bumps centred on the thresholds
    ``x = +1`` and ``x = -1``, each scaled by ``1 / temp``.
    """

    @staticmethod
    def forward(ctx, x, temp):
        """Quantize ``x``; ``temp`` (a tensor) is only stored for backward."""
        ctx.save_for_backward(x, temp)
        outside_dead_zone = (x.abs() > 1).to(x.dtype)
        return torch.sign(x) * outside_dead_zone

    @staticmethod
    def backward(ctx, grad_output):
        """Return the surrogate gradient w.r.t. ``x``; ``temp`` gets no gradient."""
        x, temp = ctx.saved_tensors

        def bump(z):
            # Derivative of the logistic sigmoid evaluated at z.
            sig = torch.sigmoid(z)
            return sig * (1.0 - sig)

        upper_bump = bump((x - 1.0) / temp)
        lower_bump = bump((x + 1.0) / temp)
        surrogate = (upper_bump + lower_bump) / temp
        return grad_output * surrogate, None
class AdaptiveBitwiseSign(nn.Module):
    """Module wrapper around ``TernarySTE`` that owns the surrogate temperature.

    The temperature lives in a float32 buffer named ``temp`` (so it is part
    of the state dict but is not a trainable parameter).
    """

    def __init__(self, initial_temp=1.0):
        super().__init__()
        temperature = torch.tensor(initial_temp, dtype=torch.float32)
        self.register_buffer('temp', temperature)

    def forward(self, x):
        """Quantize ``x`` to {-1, 0, +1} using the current temperature."""
        return TernarySTE.apply(x, self.temp)

    def anneal_temp(self, factor=0.98):
        """Multiply the temperature by ``factor`` in place, never going below 0.99."""
        self.temp.mul_(factor)
        self.temp.clamp_(min=0.99)
class InceptionDWConv2d(nn.Module):
    """Inception-style depthwise token mixer.

    The channel axis is split into four groups: three groups of
    ``int(in_channels * branch_ratio)`` channels go through a square, a
    horizontal-band and a vertical-band depthwise convolution respectively,
    and the remaining channels are passed through untouched. The four
    results are concatenated back in the same order.
    """

    def __init__(self, in_channels, square_kernel_size=3, band_kernel_size=11, branch_ratio=0.125):
        super().__init__()
        gc = int(in_channels * branch_ratio)  # channels per convolutional branch
        square_pad = square_kernel_size // 2
        band_pad = band_kernel_size // 2
        # Options shared by all three depthwise branches.
        depthwise = dict(groups=gc, padding_mode='reflect')

        self.dwconv_hw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_pad, **depthwise)
        self.dwconv_w = nn.Conv2d(
            gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_pad), **depthwise
        )
        self.dwconv_h = nn.Conv2d(
            gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_pad, 0), **depthwise
        )
        # Last entry is the identity (pass-through) share of the channels.
        self.split_indexes = (gc, gc, gc, in_channels - 3 * gc)

    def forward(self, x):
        """Mix ``x`` along dim 1 group by group and re-concatenate."""
        part_hw, part_w, part_h, part_id = torch.split(x, self.split_indexes, dim=1)
        mixed = (
            self.dwconv_hw(part_hw),
            self.dwconv_w(part_w),
            self.dwconv_h(part_h),
            part_id,
        )
        return torch.cat(mixed, dim=1)
class InceptionNeXtBlock(nn.Module):
    """Residual block: Inception depthwise mixer -> BatchNorm -> 1x1-conv MLP.

    The MLP expands ``dim`` to ``int(dim * expansion_ratio)`` channels,
    applies GELU and projects back to ``dim``; the block input is then
    added to the result.
    """

    def __init__(self, dim, expansion_ratio=4):
        super().__init__()
        self.token_mixer = InceptionDWConv2d(dim)
        self.norm = nn.BatchNorm2d(dim)
        hidden_dim = int(dim * expansion_ratio)
        expand = nn.Conv2d(dim, hidden_dim, 1)
        project = nn.Conv2d(hidden_dim, dim, 1)
        self.mlp = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x):
        """Return ``mlp(norm(token_mixer(x))) + x``."""
        mixed = self.token_mixer(x)
        normed = self.norm(mixed)
        return self.mlp(normed) + x
class TransformerBlock(nn.Module):
    """Pre-norm transformer block applied to a (B, C, H, W) feature map.

    The spatial grid is flattened into a sequence of ``H * W`` tokens of
    width ``C``, processed with self-attention and an MLP (both with
    residual connections), then reshaped back to (B, C, H, W).
    """

    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        feed_forward = [
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Apply attention + MLP over the flattened spatial positions of ``x``."""
        batch, channels, height, width = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = x.flatten(2).transpose(1, 2)
        normed = self.norm1(tokens)
        tokens = tokens + self.attn(normed, normed, normed)[0]
        tokens = tokens + self.mlp(self.norm2(tokens))
        return tokens.transpose(1, 2).reshape(batch, channels, height, width)
class MultiScaleEncoder(nn.Module):
    """Convolutional/transformer encoder producing a discrete and a continuous latent.

    Pipeline: ``PixelUnshuffle(patch)`` -> 7x7 stem conv -> ``depths[0]``
    InceptionNeXt blocks at width ``dims[0]``; then, for every following
    width, a ``PixelUnshuffle(2)`` + 1x1 conv downsample and ``depths[i]``
    InceptionNeXt blocks. The final stage additionally gets 8 transformer
    blocks with 16 heads. Overall the spatial size shrinks by
    ``patch * 2 ** (len(dims) - 1)``.

    Args:
        patch: Pixel-unshuffle factor applied to the RGB input.
        dims: Channel width of each stage.
        depths: Number of InceptionNeXt blocks in each stage.
        latent_discrete: Channels of the quantized (ternary) latent.
        latent_continuous: Channels of the continuous latent.
    """

    # Defaults are tuples (not lists) so the shared default objects are immutable.
    def __init__(self, patch=4, dims=(128, 256, 512, 1024), depths=(4, 4, 4, 4),
                 latent_discrete=128, latent_continuous=64):
        super().__init__()
        self.unshuffle = nn.PixelUnshuffle(patch)
        self.stem = nn.Conv2d(3 * patch * patch, dims[0], 7, padding=3, padding_mode='reflect')
        self.initial_blocks = nn.Sequential(
            *[InceptionNeXtBlock(dims[0]) for _ in range(depths[0])]
        )

        self.stages = nn.ModuleList()
        final_stage = len(dims) - 2
        for i in range(len(dims) - 1):
            # PixelUnshuffle(2) halves H and W and multiplies channels by 4.
            downsample = nn.Sequential(
                nn.PixelUnshuffle(2),
                nn.Conv2d(dims[i] * 4, dims[i + 1], 1)
            )
            stage_layers = [InceptionNeXtBlock(dims[i + 1]) for _ in range(depths[i + 1])]
            if i == final_stage:
                # Only the lowest-resolution stage is followed by attention blocks.
                stage_layers += [TransformerBlock(dims[i + 1], num_heads=16) for _ in range(8)]
            blocks = nn.Sequential(*stage_layers)
            self.stages.append(nn.ModuleList([downsample, blocks]))

        self.to_latent_discrete = nn.Conv2d(dims[-1], latent_discrete, 3, padding=1, padding_mode='reflect')
        self.to_latent_continuous = nn.Conv2d(dims[-1], latent_continuous, 3, padding=1, padding_mode='reflect')
        # Alias of the discrete head: registers the same conv under a second
        # state-dict prefix (`to_latent.*`).
        # NOTE(review): presumably kept for checkpoints saved under the old name — confirm.
        self.to_latent = self.to_latent_discrete
        self.quant = AdaptiveBitwiseSign()

    def forward(self, x):
        """Encode ``x`` and return ``(z_discrete, z_continuous)``.

        ``z_discrete`` is the quantized output of the discrete head;
        ``z_continuous`` is the raw output of the continuous head.
        """
        x = self.unshuffle(x)
        x = self.stem(x)
        x = self.initial_blocks(x)
        for downsample, blocks in self.stages:
            x = downsample(x)
            x = blocks(x)
        z_discrete = self.quant(self.to_latent_discrete(x))
        z_continuous = self.to_latent_continuous(x)
        return z_discrete, z_continuous
class MultiScaleDecoder(nn.Module):
    """Decoder that maps the concatenated latents back to an image.

    Pipeline: 1x1 conv from ``latent_discrete + latent_continuous`` channels
    to ``dims[-1]`` -> ``depths[-1]`` InceptionNeXt blocks plus 8 transformer
    blocks (16 heads) -> for each narrower width, a 1x1 conv +
    ``PixelShuffle(2)`` upsample followed by InceptionNeXt blocks -> 3x3 conv
    to ``3 * patch * patch`` channels -> ``PixelShuffle(patch)``.

    Args:
        patch: Final pixel-shuffle factor.
        dims: Channel width of each stage (walked from last to first).
        depths: Number of InceptionNeXt blocks in each stage.
        latent_discrete: Channels of the discrete latent input.
        latent_continuous: Channels of the continuous latent input.
    """

    # Defaults are tuples (not lists) so the shared default objects are immutable.
    def __init__(self, patch=4, dims=(128, 256, 512, 1024), depths=(4, 4, 4, 4),
                 latent_discrete=128, latent_continuous=64):
        super().__init__()
        total_latent = latent_discrete + latent_continuous
        self.from_latent = nn.Conv2d(total_latent, dims[-1], 1)
        self.initial_blocks = nn.Sequential(
            *[InceptionNeXtBlock(dims[-1]) for _ in range(depths[-1])],
            *[TransformerBlock(dims[-1], num_heads=16) for _ in range(8)]
        )

        self.stages = nn.ModuleList()
        for i in range(len(dims) - 1, 0, -1):
            # The conv emits 4x the target width so PixelShuffle(2) can trade
            # those channels for a 2x larger spatial grid.
            upsample = nn.Sequential(
                nn.Conv2d(dims[i], dims[i - 1] * 4, 1),
                nn.PixelShuffle(2)
            )
            blocks = nn.Sequential(
                *[InceptionNeXtBlock(dims[i - 1]) for _ in range(depths[i - 1])]
            )
            self.stages.append(nn.ModuleList([upsample, blocks]))

        self.to_pixels = nn.Conv2d(dims[0], 3 * patch * patch, 3, padding=1, padding_mode='reflect')
        self.shuffle = nn.PixelShuffle(patch)

    def forward(self, z_discrete, z_continuous, return_feat=False):
        """Decode the two latents (concatenated along dim 1) into an image.

        Returns:
            The decoded image, or ``(image, feat)`` when ``return_feat`` is
            true, where ``feat`` is the output of ``initial_blocks`` (before
            any upsampling).
        """
        z = torch.cat([z_discrete, z_continuous], dim=1)
        x = self.from_latent(z)
        x = self.initial_blocks(x)
        feat = x
        for upsample, blocks in self.stages:
            x = upsample(x)
            x = blocks(x)
        img = self.shuffle(self.to_pixels(x))
        if return_feat:
            return img, feat
        return img
class BinaryAutoencoder(nn.Module):
    """Autoencoder pairing ``MultiScaleEncoder`` with ``MultiScaleDecoder``.

    Reconstructions returned by ``decode`` / ``forward`` are clamped to
    ``[-1, 1]``.
    """

    def __init__(self, latent_discrete=128, latent_continuous=64):
        super().__init__()
        latent_sizes = dict(latent_discrete=latent_discrete, latent_continuous=latent_continuous)
        self.encoder = MultiScaleEncoder(**latent_sizes)
        self.decoder = MultiScaleDecoder(**latent_sizes)
        # Never assigned anything but None in this file.
        # NOTE(review): presumably a training-time head — confirm.
        self.dino_head = None
        self.latent_discrete = latent_discrete
        self.latent_continuous = latent_continuous

    def encode(self, x):
        """Return ``(z_discrete, z_continuous)`` from the encoder."""
        return self.encoder(x)

    def decode(self, z_discrete, z_continuous, return_feat=False):
        """Decode latents to a clamped image, optionally also returning decoder features."""
        if not return_feat:
            return torch.clamp(self.decoder(z_discrete, z_continuous), -1, 1)
        recon, feat = self.decoder(z_discrete, z_continuous, return_feat=True)
        return torch.clamp(recon, -1, 1), feat

    def forward(self, x):
        """Encode then decode ``x``; returns ``(recon, z_discrete, z_continuous)``."""
        z_discrete, z_continuous = self.encode(x)
        return self.decode(z_discrete, z_continuous), z_discrete, z_continuous
# ==========================================================
# Metrics
# ==========================================================
def compute_psnr(pred, target):
    """Return the PSNR (in dB, as a Python float) of two images in [-1, 1].

    Both tensors are first mapped to [0, 1], so the peak signal value is 1.
    A small epsilon is added to the MSE to keep the logarithm finite when
    the images are identical.
    """
    pred_unit = (pred + 1) / 2
    target_unit = (target + 1) / 2
    mean_sq_err = F.mse_loss(pred_unit, target_unit)
    return (10 * torch.log10(1.0 / (mean_sq_err + 1e-8))).item()
def compute_ssim(pred, target):
    """Return the SSIM of two single-image batches with values in [-1, 1].

    Each tensor is rescaled to [0, 1], has its leading dim squeezed, is
    permuted to channels-last and converted to a CPU NumPy array before
    being handed to scikit-image.
    """
    def to_hwc_array(batch):
        return ((batch + 1) / 2).squeeze(0).permute(1, 2, 0).cpu().numpy()

    pred_np = to_hwc_array(pred)
    target_np = to_hwc_array(target)
    # Both `multichannel` and `channel_axis` are passed on purpose here.
    # NOTE(review): presumably for compatibility across scikit-image versions — confirm before removing either.
    return ssim(target_np, pred_np, multichannel=True, channel_axis=2, data_range=1.0)
# ==========================================================
# Preprocessing and Demo Functions
# ==========================================================
def resize_and_crop(image, min_size=512, multiple=32):
    """Scale the shorter side to ``min_size``, then center-crop to multiples of ``multiple``.

    The longer side is scaled by the same factor (truncated to an int) and
    the resize uses bilinear resampling. The crop removes the remainder of
    each dimension modulo ``multiple``, split evenly between both edges.
    """
    src_w, src_h = image.size

    # The shorter side becomes exactly min_size; the other follows the aspect ratio.
    if src_w < src_h:
        scaled_w, scaled_h = min_size, int(src_h * (min_size / src_w))
    else:
        scaled_w, scaled_h = int(src_w * (min_size / src_h)), min_size
    image = image.resize((scaled_w, scaled_h), Image.Resampling.BILINEAR)

    # Largest sizes not exceeding the scaled ones that are divisible by `multiple`.
    crop_w = scaled_w - scaled_w % multiple
    crop_h = scaled_h - scaled_h % multiple

    left = (scaled_w - crop_w) // 2
    top = (scaled_h - crop_h) // 2
    box = (left, top, left + crop_w, top + crop_h)
    return image.crop(box)
def load_model(device='cpu'):
    """Download the pretrained checkpoint and return the model in eval mode.

    Builds a ``BinaryAutoencoder`` with 256 discrete and 32 continuous latent
    channels, fetches ``latest.pt`` over HTTP, and loads ``checkpoint['model']``
    non-strictly. Any keys that did not match are printed rather than being
    silently ignored.

    Args:
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The loaded ``BinaryAutoencoder``.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    print("Loading model...")
    # Updated to latent_continuous=32 and latent_discrete=256
    model = BinaryAutoencoder(latent_discrete=256, latent_continuous=32).to(device)

    # Download checkpoint
    url = "https://huggingface.co/Shio-Koube/vae_binary/resolve/main/latest.pt"
    # The timeout bounds connect and each socket read, so a stalled
    # connection cannot hang the app forever.
    response = requests.get(url, timeout=60)
    # Fail here on 4xx/5xx instead of handing an error page to torch.load.
    response.raise_for_status()

    # NOTE(security): torch.load unpickles the downloaded bytes; only point
    # `url` at a source you trust.
    checkpoint = torch.load(BytesIO(response.content), map_location=device)

    # strict=False tolerates key mismatches; report them so a partially
    # loaded (i.e. partly randomly initialised) model does not go unnoticed.
    load_result = model.load_state_dict(checkpoint['model'], strict=False)
    if load_result.missing_keys:
        print(f"Warning: {len(load_result.missing_keys)} missing keys in checkpoint: "
              f"{load_result.missing_keys[:10]}")
    if load_result.unexpected_keys:
        print(f"Warning: {len(load_result.unexpected_keys)} unexpected keys in checkpoint: "
              f"{load_result.unexpected_keys[:10]}")

    model.eval()
    print("Model loaded successfully!")
    return model
def reconstruct_with_channels(model, img_tensor, num_channels, device='cpu'):
    """Reconstruct ``img_tensor`` keeping only the first ``num_channels`` continuous channels.

    The image is encoded, continuous-latent channels from index
    ``num_channels`` onward are set to zero (on a copy), and the result is
    decoded together with the untouched discrete latent. Runs without
    gradient tracking. ``device`` is accepted but not used.
    """
    with torch.no_grad():
        z_discrete, z_continuous = model.encode(img_tensor)

        keep_everything = num_channels >= z_continuous.shape[1]
        if not keep_everything:
            # Work on a clone so the encoder output itself is never modified.
            z_continuous = z_continuous.clone()
            z_continuous[:, num_channels:, :, :] = 0

        recon = model.decode(z_discrete, z_continuous)
    return recon
def process_image(image):
    """Reconstruct an uploaded image with 0, 8, 16 and 32 continuous channels.

    Args:
        image: PIL image supplied by the Gradio input, or ``None`` when
            nothing has been uploaded.

    Returns:
        A tuple ``(original_pil, grid_img, metrics_data)``: the preprocessed
        input, a 2x2 grid of the four reconstructions (row-major, in the
        order of the channel counts), and one ``[label, PSNR, SSIM]`` row of
        formatted strings per channel count.

    Raises:
        gr.Error: If no image was provided.
    """
    # Gradio passes None when the button is clicked without an image.
    if image is None:
        raise gr.Error("Please upload an image before generating reconstructions.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Cache the model on the function object so it is loaded once per process.
    if not hasattr(process_image, 'model'):
        process_image.model = load_model(device)
    model = process_image.model

    # The normalisation below assumes exactly three channels, so force RGB
    # (covers RGBA, grayscale and palette inputs).
    rgb_image = image.convert('RGB')

    # Preprocess image: Resize shortest side to 512, ensure divisible by 32
    processed_image = resize_and_crop(rgb_image, min_size=512, multiple=32)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    img_tensor = transform(processed_image).unsqueeze(0).to(device)

    def _to_pil(batch):
        """Convert the first image of a [-1, 1] batch to an 8-bit PIL image."""
        arr = ((batch[0] + 1) / 2).permute(1, 2, 0).cpu().numpy()
        # Round before casting: plain truncation turns values such as
        # 253.99998 (float round-trip error) into 253 instead of 254.
        return Image.fromarray(np.clip(np.rint(arr * 255), 0, 255).astype(np.uint8))

    # Original image for comparison
    original_pil = _to_pil(img_tensor)

    # Channel counts to test: 4 images total for a 2x2 grid
    all_channels = [0, 8, 16, 32]
    results = []
    metrics_data = []
    for num_ch in all_channels:
        # NOTE: each call re-encodes the image; only the masking differs.
        recon = reconstruct_with_channels(model, img_tensor, num_ch, device)
        results.append(_to_pil(recon))

        psnr = compute_psnr(recon, img_tensor)
        ssim_val = compute_ssim(recon, img_tensor)
        metrics_data.append([
            f'{num_ch}ch',
            f'{psnr:.2f}',
            f'{ssim_val:.4f}'
        ])

    # Create 2x2 grid dynamically based on the final image dimensions
    cell_width, cell_height = results[0].size
    grid_width = 2
    grid_height = 2
    grid_img = Image.new('RGB', (cell_width * grid_width, cell_height * grid_height))
    for idx, img in enumerate(results[:grid_width * grid_height]):
        row, col = divmod(idx, grid_width)
        grid_img.paste(img, (col * cell_width, row * cell_height))

    return original_pil, grid_img, metrics_data
# ==========================================================
# Gradio Interface
# ==========================================================
# NOTE(review): the nesting of the layout below was reconstructed from a
# whitespace-flattened source — confirm against the intended page layout.
with gr.Blocks(title="Binary VAE with Continuous Channels Demo") as demo:
    # Page header / description.
    gr.Markdown("""
# Binary VAE with Continuous Channels
This demo shows how the reconstruction quality improves as we use more continuous latent channels.
The model uses 256 discrete (ternary) channels + 0-32 continuous channels. Input images are resized so their shortest edge is 512, and then cropped so dimensions are perfectly divisible by 32.
**Channel counts tested:** 0, 8, 16, 32
""")

    # Input on the left, preprocessed original on the right.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            submit_btn = gr.Button("Generate Reconstructions", variant="primary")
        with gr.Column():
            original_output = gr.Image(label="Original (preprocessed)")

    # Reconstruction grid.
    gr.Markdown("## Reconstructions (2x2 Grid)")
    gr.Markdown("**Top row:** 0ch, 8ch | **Bottom row:** 16ch, 32ch")
    grid_output = gr.Image(label="Reconstructions Grid")

    # Per-channel-count quality metrics.
    gr.Markdown("## Metrics")
    metrics_table = gr.Dataframe(
        headers=['Channels', 'PSNR (dB)', 'SSIM'],
        label="Quality Metrics",
    )

    # The three outputs match the tuple returned by process_image, in order.
    submit_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[original_output, grid_output, metrics_table],
    )

    gr.Markdown("""
### Notes:
- **PSNR**: Peak Signal-to-Noise Ratio (higher is better, >30 dB is good)
- **SSIM**: Structural Similarity Index (0-1, higher is better)
- The model has 32 continuous channels max.
- Discrete latent (256 ternary channels) is always active.
""")

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()