"""
E3Diff: High-Resolution SAR-to-Optical Translation
HuggingFace Spaces Deployment

Features:
- Full resolution processing with seamless tiling
- Multi-step inference for maximum quality
- TIFF output support
- Professional post-processing
"""

import math
import os
import sys
import tempfile
import time
from inspect import isfunction
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image, ImageEnhance
from tqdm import tqdm


# ============================================================================
# SoftPool Implementation (Pure PyTorch)
# ============================================================================

def soft_pool2d(x, kernel_size=(2, 2), stride=None, force_inplace=False):
    """Pool a 4-D tensor with softmax-weighted averaging over each window.

    Every output value is ``sum(x * exp(x)) / sum(exp(x))`` taken over one
    ``kernel_size`` window (no padding).

    Args:
        x: Tensor of shape ``(batch, channels, height, width)``.
        kernel_size: Window size, an int or a ``(kh, kw)`` pair.
        stride: Window stride, an int or ``(sh, sw)`` pair; defaults to
            ``kernel_size``.
        force_inplace: Unused; kept so the signature matches the call sites
            that were written against the external ``SoftPool`` package.

    Returns:
        Tensor of shape ``(batch, channels, out_h, out_w)``.
    """
    if stride is None:
        stride = kernel_size
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride)

    batch, channels, height, width = x.shape
    kh, kw = kernel_size
    sh, sw = stride
    out_h = (height - kh) // sh + 1
    out_w = (width - kw) // sw + 1

    # unfold lays the windows out as (batch, channels * kh * kw, n_windows).
    x_unfold = F.unfold(x, kernel_size=kernel_size, stride=stride)
    x_unfold = x_unfold.view(batch, channels, kh * kw, out_h * out_w)

    # Subtract the per-window max before exponentiating for numerical
    # stability; the shift cancels in the ratio below.
    x_max = x_unfold.max(dim=2, keepdim=True)[0]
    exp_x = torch.exp(x_unfold - x_max)
    softpool = (x_unfold * exp_x).sum(dim=2) / (exp_x.sum(dim=2) + 1e-8)
    return softpool.view(batch, channels, out_h, out_w)


class SoftPool2d(nn.Module):
    """``nn.Module`` wrapper around :func:`soft_pool2d`."""

    def __init__(self, kernel_size=(2, 2), stride=None, force_inplace=False):
        super(SoftPool2d, self).__init__()
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride if stride is not None else self.kernel_size

    def forward(self, x):
        return soft_pool2d(x, self.kernel_size, self.stride)


# Monkey-patch SoftPool into the expected location
class SoftPoolModule:
    """Stand-in object registered as the ``SoftPool`` module."""

    soft_pool2d = staticmethod(soft_pool2d)
    SoftPool2d = SoftPool2d


# Any later ``import SoftPool`` resolves to this pure-PyTorch implementation.
sys.modules['SoftPool'] = SoftPoolModule()


# ============================================================================
# Model Architecture
# ============================================================================

def exists(x):
    """Return True when *x* is not None."""
    return x is not None


def default(val, d):
    """Return *val* if it is not None, else *d* (called first if it is a function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d


class PositionalEncoding(nn.Module):
    """Sinusoidal embedding of a per-sample noise level."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, noise_level):
        count = self.dim // 2
        step = torch.arange(count, dtype=noise_level.dtype, device=noise_level.device) / count
        encoding = noise_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
        encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
        return encoding


class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        return x * torch.sigmoid(x)


class FeatureWiseAffine(nn.Module):
    """Inject the noise embedding into a feature map as a bias or scale+bias."""

    def __init__(self, in_channels, out_channels, use_affine_level=False):
        super(FeatureWiseAffine, self).__init__()
        self.use_affine_level = use_affine_level
        # Twice the channels when both gamma and beta are predicted.
        self.noise_func = nn.Sequential(
            nn.Linear(in_channels, out_channels * (1 + self.use_affine_level))
        )

    def forward(self, x, noise_embed):
        batch = x.shape[0]
        if self.use_affine_level:
            gamma, beta = self.noise_func(noise_embed).view(batch, -1, 1, 1).chunk(2, dim=1)
            x = (1 + gamma) * x + beta
        else:
            x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1)
        return x


class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution."""

    def __init__(self, dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, dim, 3, padding=1)

    def forward(self, x):
        return self.conv(self.up(x))


class Downsample(nn.Module):
    """Stride-2 3x3 convolution."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Block(nn.Module):
    """GroupNorm -> Swish -> optional Dropout -> 3x3 convolution."""

    def __init__(self, dim, dim_out, groups=32, dropout=0, stride=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.GroupNorm(groups, dim),
            Swish(),
            nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
            nn.Conv2d(dim, dim_out, 3, stride=stride, padding=1),
        )

    def forward(self, x):
        return self.block(x)


class ResnetBlock(nn.Module):
    """Residual block conditioned on a noise embedding and a condition feature map."""

    def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0,
                 use_affine_level=False, norm_groups=32):
        super().__init__()
        self.noise_func = FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level)
        self.c_func = nn.Conv2d(dim_out, dim_out, 1)
        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb, c):
        h = self.block1(x)
        h = self.noise_func(h, time_emb)
        h = self.block2(h)
        # The condition features are projected by a 1x1 conv and added in;
        # c must therefore have dim_out channels and h's spatial size.
        h = self.c_func(c) + h
        return h + self.res_conv(x)


class SelfAttention(nn.Module):
    """Spatial self-attention over all positions of a feature map."""

    def __init__(self, in_channel, n_head=1, norm_groups=32):
        super().__init__()
        self.n_head = n_head
        self.norm = nn.GroupNorm(norm_groups, in_channel)
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.out = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, input, t=None, save_flag=None, file_num=None):
        # t, save_flag and file_num are accepted but not used here.
        batch, channel, height, width = input.shape
        n_head = self.n_head
        head_dim = channel // n_head

        norm = self.norm(input)
        qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
        query, key, value = qkv.chunk(3, dim=2)

        attn = torch.einsum("bnchw, bncyx -> bnhwyx", query, key).contiguous() / math.sqrt(channel)
        # Softmax over all key positions (y, x) flattened together.
        attn = attn.view(batch, n_head, height, width, -1)
        attn = torch.softmax(attn, -1)
        attn = attn.view(batch, n_head, height, width, height, width)

        out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
        out = self.out(out.view(batch, channel, height, width))
        return out + input


class ResnetBlocWithAttn(nn.Module):
    """A :class:`ResnetBlock` optionally followed by :class:`SelfAttention`."""

    def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32,
                 dropout=0, with_attn=False, size=256):
        super().__init__()
        self.with_attn = with_attn
        self.res_block = ResnetBlock(dim, dim_out, noise_level_emb_dim,
                                     norm_groups=norm_groups, dropout=dropout)
        if with_attn:
            self.attn = SelfAttention(dim_out, norm_groups=norm_groups)

    def forward(self, x, time_emb, c, t=0, save_flag=False, file_i=0):
        x = self.res_block(x, time_emb, c)
        if self.with_attn:
            x = self.attn(x, t=t, save_flag=save_flag, file_num=file_i)
        return x


class ResBlock_normal(nn.Module):
    """Plain residual block (no noise or condition inputs)."""

    def __init__(self, dim, dim_out, dropout=0, norm_groups=32):
        super().__init__()
        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        h = self.block1(x)
        h = self.block2(h)
        return h + self.res_conv(x)


class CPEN(nn.Module):
    """Condition encoder producing five feature maps at successively halved resolution.

    Output channel counts are 64, 128, 256, 512 and 1024.
    """

    def __init__(self, inchannel=1):
        super(CPEN, self).__init__()
        self.pool = SoftPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.E1 = nn.Sequential(nn.Conv2d(inchannel, 64, kernel_size=3, padding=1), Swish())
        self.E2 = nn.Sequential(ResBlock_normal(64, 128, dropout=0, norm_groups=16),
                                ResBlock_normal(128, 128, dropout=0, norm_groups=16))
        self.E3 = nn.Sequential(ResBlock_normal(128, 256, dropout=0, norm_groups=16),
                                ResBlock_normal(256, 256, dropout=0, norm_groups=16))
        self.E4 = nn.Sequential(ResBlock_normal(256, 512, dropout=0, norm_groups=16),
                                ResBlock_normal(512, 512, dropout=0, norm_groups=16))
        self.E5 = nn.Sequential(ResBlock_normal(512, 512, dropout=0, norm_groups=16),
                                ResBlock_normal(512, 1024, dropout=0, norm_groups=16))

    def forward(self, x):
        x1 = self.E1(x)
        x2 = self.E2(self.pool(x1))
        x3 = self.E3(self.pool(x2))
        x4 = self.E4(self.pool(x3))
        x5 = self.E5(self.pool(x4))
        return x1, x2, x3, x4, x5


class UNet(nn.Module):
    """Conditional U-Net denoiser.

    The first ``condition_ch`` channels of the input are split off, encoded by
    :class:`CPEN`, and the resulting feature maps are fed to every residual
    block; the remaining channels are the image being denoised.

    Args:
        in_channel: Channels of the denoised image part of the input.
        out_channel: Output channels (falls back to ``in_channel`` if None).
        inner_channel: Base channel width.
        norm_groups: Number of GroupNorm groups.
        channel_mults: Channel multiplier per resolution level.
        attn_res: Resolutions at which self-attention is enabled; a single
            int is accepted and treated as a one-element collection.
        res_blocks: Residual blocks per level on the down path (the up path
            uses ``res_blocks + 1``).
        dropout: Dropout probability inside residual blocks.
        with_noise_level_emb: Whether to build the noise-level embedding MLP.
        image_size: Input resolution, used only to decide where attention goes.
        condition_ch: Number of leading input channels holding the condition.
    """

    def __init__(self, in_channel=6, out_channel=3, inner_channel=32, norm_groups=32,
                 channel_mults=(1, 2, 4, 8, 8), attn_res=(8,), res_blocks=3, dropout=0,
                 with_noise_level_emb=True, image_size=128, condition_ch=3):
        super().__init__()

        # The original default was written ``(8)``, which is the int 8 and
        # makes ``now_res in attn_res`` raise TypeError; accept a bare int.
        if isinstance(attn_res, int):
            attn_res = (attn_res,)

        if with_noise_level_emb:
            noise_level_channel = inner_channel
            self.noise_level_mlp = nn.Sequential(
                PositionalEncoding(inner_channel),
                nn.Linear(inner_channel, inner_channel * 4),
                Swish(),
                nn.Linear(inner_channel * 4, inner_channel),
            )
        else:
            noise_level_channel = None
            self.noise_level_mlp = None

        self.res_blocks = res_blocks
        num_mults = len(channel_mults)
        self.num_mults = num_mults
        pre_channel = inner_channel
        # Channel count of every skip connection, consumed in reverse below.
        feat_channels = [pre_channel]
        now_res = image_size

        downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)]
        for ind in range(num_mults):
            is_last = (ind == num_mults - 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks):
                downs.append(ResnetBlocWithAttn(
                    pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel,
                    norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, size=now_res))
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            if not is_last:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
                now_res = now_res // 2
        self.downs = nn.ModuleList(downs)

        self.mid = nn.ModuleList([
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups, dropout=dropout, with_attn=True, size=now_res),
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups, dropout=dropout, with_attn=False, size=now_res),
        ])

        ups = []
        for ind in reversed(range(num_mults)):
            is_last = (ind < 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks + 1):
                ups.append(ResnetBlocWithAttn(
                    pre_channel + feat_channels.pop(), channel_mult,
                    noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                    dropout=dropout, with_attn=use_attn, size=now_res))
                pre_channel = channel_mult
            if not is_last:
                ups.append(Upsample(pre_channel))
                now_res = now_res * 2
        self.ups = nn.ModuleList(ups)

        self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)
        self.condition = CPEN(inchannel=condition_ch)
        self.condition_ch = condition_ch

    def forward(self, x, time, img_s1=None, class_label=None, return_condition=False, t_ori=0):
        """Run the denoiser.

        Args:
            x: Condition and noisy image concatenated on the channel axis
                (condition first).
            time: Noise level passed to the embedding MLP.
            img_s1, class_label, t_ori: Accepted but unused.
            return_condition: Also return the five condition feature maps.

        Returns:
            The prediction, or ``(prediction, [c1..c5])`` when
            ``return_condition`` is set.
        """
        condition = x[:, :self.condition_ch, ...].clone()
        x = x[:, self.condition_ch:, ...]

        c1, c2, c3, c4, c5 = self.condition(condition)
        # One condition map per residual block on the way down.
        c = [level for level in (c1, c2, c3, c4, c5) for _ in range(self.res_blocks)]

        t = self.noise_level_mlp(time) if exists(self.noise_level_mlp) else None

        feats = []
        i = 0
        for layer in self.downs:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t, c[i])
                i += 1
            else:
                x = layer(x)
            feats.append(x)

        for layer in self.mid:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t, c5)
            else:
                x = layer(x)

        # The up path has one extra block per level and walks coarse to fine.
        c = [level for level in (c5, c4, c3, c2, c1) for _ in range(self.res_blocks + 1)]
        i = 0
        for layer in self.ups:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(torch.cat((x, feats.pop()), dim=1), t, c[i])
                i += 1
            else:
                x = layer(x)

        if not return_condition:
            return self.final_conv(x)
        return self.final_conv(x), [c1, c2, c3, c4, c5]


# ============================================================================
# E3Diff High-Resolution Inference
# ============================================================================

class E3DiffHighRes:
    """Tile-based SAR-to-optical inference around the :class:`UNet` denoiser."""

    def __init__(self, device="cuda"):
        # Falls back to CPU whenever CUDA is unavailable.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model = None
        self.image_size = 256

    def load_model(self, weights_path=None):
        """Build the UNet and load its weights.

        Args:
            weights_path: Path to a checkpoint; when None the checkpoint is
                downloaded from the HuggingFace Hub.
        """
        if weights_path is None:
            # Download from HuggingFace
            weights_path = hf_hub_download(
                repo_id="Dhenenjay/E3Diff-SAR2Optical",
                filename="I700000_E719_gen.pth",
            )

        # Build UNet
        self.model = UNet(
            in_channel=3,
            out_channel=3,
            norm_groups=16,
            inner_channel=64,
            channel_mults=[1, 2, 4, 8, 16],
            attn_res=[],
            res_blocks=1,
            dropout=0,
            image_size=self.image_size,
            condition_ch=3,
        ).to(self.device)

        # Load weights.
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # point this at checkpoints from a trusted source.
        state_dict = torch.load(weights_path, map_location=self.device, weights_only=False)

        # Filter only UNet weights
        unet_dict = {k.replace('denoise_fn.', ''): v
                     for k, v in state_dict.items() if k.startswith('denoise_fn.')}
        load_result = self.model.load_state_dict(unet_dict, strict=False)

        # strict=False hides mismatches; surface them so a wrong checkpoint
        # does not silently run with randomly initialised weights.
        if not unet_dict:
            print("Warning: checkpoint contains no 'denoise_fn.' weights; model is untrained")
        elif load_result.missing_keys:
            print(f"Warning: {len(load_result.missing_keys)} model weights missing from checkpoint")

        self.model.eval()
        print(f"Model loaded on {self.device}")

    @torch.no_grad()
    def translate_tile(self, tile_tensor, num_steps=1):
        """Translate a single 256x256 tile.

        Args:
            tile_tensor: Condition tile on ``self.device``; it is concatenated
                with a ``(batch, 3, image_size, image_size)`` noise tensor, so
                it must match that batch and spatial size.
            num_steps: Number of sampling steps.

        Returns:
            Tensor of shape ``(batch, 3, image_size, image_size)``; the final
            step returns the prediction clamped to [-1, 1].
        """
        batch_size = tile_tensor.shape[0]

        # Initialize noise
        noise = torch.randn(batch_size, 3, self.image_size, self.image_size, device=self.device)

        # DDIM sampling
        total_timesteps = 1000
        ts = torch.linspace(total_timesteps, 0, num_steps + 1).to(self.device).long()

        # Create beta schedule
        betas = torch.linspace(1e-6, 1e-2, total_timesteps, device=self.device)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        sqrt_alphas_cumprod_prev = torch.sqrt(
            torch.cat([torch.ones(1, device=self.device), alphas_cumprod]))

        x = noise
        for i in range(1, num_steps + 1):
            cur_t = ts[i - 1] - 1
            prev_t = ts[i] - 1

            noise_level = sqrt_alphas_cumprod_prev[cur_t].repeat(batch_size, 1)

            alpha_prod_t = alphas_cumprod[cur_t]
            # prev_t is -1 on the last step, meaning "fully denoised".
            alpha_prod_t_prev = (alphas_cumprod[prev_t] if prev_t >= 0
                                 else torch.tensor(1.0, device=self.device))
            beta_prod_t = 1 - alpha_prod_t

            # Model prediction
            model_input = torch.cat([tile_tensor, x], dim=1)
            model_output = self.model(model_input, noise_level)

            # DDIM update
            pred_original = (x - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
            pred_original = pred_original.clamp(-1, 1)

            sigma_2 = (0.8 * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)
                       * (1 - alpha_prod_t / alpha_prod_t_prev))
            pred_dir = (1 - alpha_prod_t_prev - sigma_2) ** 0.5 * model_output

            if i < num_steps:
                noise = torch.randn_like(x)
                x = alpha_prod_t_prev ** 0.5 * pred_original + pred_dir + sigma_2 ** 0.5 * noise
            else:
                x = pred_original

        return x

    def create_blend_weights(self, tile_size, overlap):
        """Create smooth blending weights for seamless tiling.

        Args:
            tile_size: Side length of the square tile.
            overlap: Width in pixels of the ramp applied to each edge.

        Returns:
            Array of shape ``(tile_size, tile_size, 1)`` with values in (0, 1].

        Raises:
            ValueError: If ``overlap`` is negative or larger than ``tile_size``.
        """
        if not 0 <= overlap <= tile_size:
            raise ValueError(f"overlap must be in [0, {tile_size}], got {overlap}")

        # Create 2D weight matrix
        weight = np.ones((tile_size, tile_size))
        if overlap == 0:
            # Nothing to ramp; ``weight[-0:]`` would select the whole array.
            return weight[:, :, np.newaxis]

        # Linear ramp for overlap regions. The endpoints 0 and 1 are excluded
        # so no pixel gets weight 0: pixels on the outer image border are
        # covered by a single tile and would otherwise normalise to black.
        # Opposite ramps still sum to 1 across an overlap.
        ramp = np.linspace(0, 1, overlap + 2)[1:-1]

        # Apply ramps to edges
        weight[:overlap, :] *= ramp[:, np.newaxis]          # Top
        weight[-overlap:, :] *= ramp[::-1, np.newaxis]      # Bottom
        weight[:, :overlap] *= ramp[np.newaxis, :]          # Left
        weight[:, -overlap:] *= ramp[np.newaxis, ::-1]      # Right

        return weight[:, :, np.newaxis]

    @staticmethod
    def _to_float_rgb(image):
        """Convert a PIL image or array into a float32 ``(H, W, 3)`` array."""
        if isinstance(image, Image.Image):
            if image.mode != 'RGB':
                image = image.convert('RGB')
            return np.array(image).astype(np.float32) / 255.0

        img_np = np.asarray(image)
        if img_np.dtype == np.uint8:
            img_np = img_np.astype(np.float32) / 255.0
        else:
            # Other dtypes are taken to already be scaled to [0, 1].
            img_np = img_np.astype(np.float32, copy=False)

        if img_np.ndim == 2:
            img_np = img_np[:, :, np.newaxis]
        if img_np.ndim != 3:
            raise ValueError(f"expected a 2-D or 3-D image array, got shape {img_np.shape}")
        if img_np.shape[2] == 1:
            img_np = np.repeat(img_np, 3, axis=2)
        elif img_np.shape[2] == 4:
            img_np = img_np[:, :, :3]  # drop alpha
        elif img_np.shape[2] != 3:
            raise ValueError(f"expected 1, 3 or 4 channels, got {img_np.shape[2]}")
        return img_np

    @staticmethod
    def _padded_length(length, tile_size, step):
        """Return the smallest ``tile_size + k * step`` that is >= *length*."""
        if length <= tile_size:
            return tile_size
        return tile_size + math.ceil((length - tile_size) / step) * step

    def translate_full_resolution(self, image, num_steps=1, overlap=64, progress_callback=None):
        """
        Translate full resolution image using seamless tiling.

        Args:
            image: PIL image, or array of shape ``(H, W)`` / ``(H, W, C)``
                (uint8, or float already scaled to [0, 1]).
            num_steps: Sampling steps per tile.
            overlap: Overlap between neighbouring tiles in pixels.
            progress_callback: Optional callable receiving the completed
                fraction in (0, 1] after each tile.

        Returns:
            float32 array of shape ``(H, W, 3)`` with values in [0, 1].

        Raises:
            ValueError: If the image is empty or ``overlap`` is not in
                ``[0, tile_size)``.
        """
        img_np = self._to_float_rgb(image)
        h, w = img_np.shape[:2]
        if h == 0 or w == 0:
            raise ValueError("cannot translate an empty image")

        tile_size = self.image_size
        num_steps = int(num_steps)
        overlap = int(overlap)
        if not 0 <= overlap < tile_size:
            raise ValueError(f"overlap must be in [0, {tile_size}), got {overlap}")
        step = tile_size - overlap

        # Pad image to ensure full coverage; images smaller than one tile are
        # padded up to a full tile so at least one tile is always processed.
        h_pad = self._padded_length(h, tile_size, step)
        w_pad = self._padded_length(w, tile_size, step)
        img_padded = np.pad(img_np, ((0, h_pad - h), (0, w_pad - w), (0, 0)), mode='reflect')

        # Output arrays
        output = np.zeros((h_pad, w_pad, 3), dtype=np.float32)
        weights = np.zeros((h_pad, w_pad, 1), dtype=np.float32)

        # Blending weights
        blend_weight = self.create_blend_weights(tile_size, overlap)

        # Calculate tile positions
        y_positions = list(range(0, h_pad - tile_size + 1, step))
        x_positions = list(range(0, w_pad - tile_size + 1, step))
        total_tiles = len(y_positions) * len(x_positions)

        print(f"Processing {total_tiles} tiles ({len(x_positions)}x{len(y_positions)})...")

        tile_idx = 0
        for y in y_positions:
            for x in x_positions:
                # Extract tile
                tile = img_padded[y:y + tile_size, x:x + tile_size]

                # Convert to tensor [-1, 1]
                tile_tensor = torch.from_numpy(tile).permute(2, 0, 1).unsqueeze(0)
                tile_tensor = tile_tensor * 2.0 - 1.0
                tile_tensor = tile_tensor.to(self.device)

                # Translate
                result_tensor = self.translate_tile(tile_tensor, num_steps)

                # Convert back to numpy [0, 1]
                result = result_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
                result = (result + 1.0) / 2.0
                result = np.clip(result, 0, 1)

                # Add to output with blending
                output[y:y + tile_size, x:x + tile_size] += result * blend_weight
                weights[y:y + tile_size, x:x + tile_size] += blend_weight

                tile_idx += 1
                if progress_callback:
                    progress_callback(tile_idx / total_tiles)

        # Normalize by weights
        output = output / (weights + 1e-8)

        # Crop to original size
        output = output[:h, :w]

        return output

    def enhance_output(self, image, contrast=1.1, sharpness=1.15, color=1.1):
        """Apply professional post-processing.

        Args:
            image: PIL image, or float array in [0, 1] (converted to uint8).
            contrast, sharpness, color: Enhancement factors passed to the
                corresponding ``PIL.ImageEnhance`` classes, applied in that
                order.

        Returns:
            The enhanced PIL image.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray((image * 255).astype(np.uint8))

        # Contrast
        image = ImageEnhance.Contrast(image).enhance(contrast)
        # Sharpness
        image = ImageEnhance.Sharpness(image).enhance(sharpness)
        # Color saturation
        image = ImageEnhance.Color(image).enhance(color)

        return image


# ============================================================================
# Gradio Interface
# ============================================================================

# Lazily created on the first request so the UI starts before weights load.
model = None


def _stretch_to_uint8(data, valid):
    """Clip *data* to the 2nd-98th percentile of *valid* and rescale to uint8."""
    p2, p98 = np.percentile(valid, [2, 98])
    if np.issubdtype(data.dtype, np.floating):
        # Casting NaN/inf to uint8 is undefined; pin them to the clip bounds.
        data = np.nan_to_num(data, nan=p2, posinf=p98, neginf=p2)
    data = np.clip(data, p2, p98)
    return ((data - p2) / (p98 - p2 + 1e-8) * 255).astype(np.uint8)


def load_sar_image(filepath):
    """Load SAR image from various formats.

    The first band is read with ``rasterio`` when that package is available;
    float and uint16 data are contrast-stretched to 8 bits. If rasterio is
    missing or cannot read the file, PIL is used instead.

    Args:
        filepath: Path of the image file.

    Returns:
        An RGB ``PIL.Image``.
    """
    try:
        import rasterio
    except ImportError:
        rasterio = None

    if rasterio is not None:
        try:
            with rasterio.open(filepath) as src:
                data = src.read(1)

            if data.dtype in [np.float32, np.float64]:
                valid = data[np.isfinite(data)]
                if len(valid) > 0:
                    data = _stretch_to_uint8(data, valid)
            elif data.dtype == np.uint16:
                data = _stretch_to_uint8(data, data)

            return Image.fromarray(data).convert('RGB')
        except Exception as exc:
            # Deliberate best-effort: anything rasterio cannot handle is
            # retried with PIL below, but say why instead of hiding it.
            print(f"rasterio could not load {filepath!r} ({exc}); falling back to PIL")

    return Image.open(filepath).convert('RGB')


def translate_sar(image, num_steps, overlap, enhance, progress=gr.Progress()):
    """Main translation function.

    Args:
        image: PIL image from the UI, or a file path.
        num_steps: Sampling steps per tile.
        overlap: Tile overlap in pixels.
        enhance: Whether to apply :meth:`E3DiffHighRes.enhance_output`.
        progress: Gradio progress tracker.

    Returns:
        ``(result image, path of the saved TIFF, info string)``.

    Raises:
        gr.Error: If no image was supplied.
    """
    global model

    if image is None:
        raise gr.Error("Please upload a SAR image first.")

    # Slider values may arrive as floats; the sampler needs ints.
    num_steps = int(num_steps)
    overlap = int(overlap)

    if model is None:
        progress(0, desc="Loading model...")
        model = E3DiffHighRes()
        model.load_model()

    progress(0.1, desc="Processing image...")

    # Handle file upload
    if isinstance(image, str):
        image = load_sar_image(image)

    w, h = image.size
    print(f"Input size: {w}x{h}")

    # Progress callback
    def update_progress(p):
        progress(0.1 + 0.8 * p, desc=f"Translating... {int(p*100)}%")

    # Translate
    start = time.time()
    result = model.translate_full_resolution(
        image,
        num_steps=num_steps,
        overlap=overlap,
        progress_callback=update_progress,
    )
    elapsed = time.time() - start

    progress(0.9, desc="Post-processing...")

    # Convert to PIL
    result_pil = Image.fromarray((result * 255).astype(np.uint8))

    # Enhance if requested
    if enhance:
        result_pil = model.enhance_output(result_pil)

    # Save as TIFF. The file is created atomically (mktemp is race-prone and
    # deprecated) and kept on disk so Gradio can serve it for download.
    with tempfile.NamedTemporaryFile(suffix='.tiff', delete=False) as tmp:
        tiff_path = tmp.name
    # Pillow's documented name for LZW compression is "tiff_lzw".
    result_pil.save(tiff_path, format='TIFF', compression='tiff_lzw')

    progress(1.0, desc="Complete!")

    info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"

    return result_pil, tiff_path, info


# Create Gradio interface
with gr.Blocks(title="E3Diff: SAR-to-Optical Translation", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🛰️ E3Diff: High-Resolution SAR-to-Optical Translation

    **CVPR PBVS2025 Challenge Winner** | Upload any SAR image and get a photorealistic optical translation.

    - Supports full resolution processing with seamless tiling
    - Multiple quality levels (1-8 inference steps)
    - Professional post-processing
    - TIFF output for commercial use
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="SAR Input", type="pil")

            with gr.Row():
                num_steps = gr.Slider(1, 8, value=1, step=1,
                                      label="Quality Steps (1=fast, 4-8=high quality)")
                overlap = gr.Slider(16, 128, value=64, step=16,
                                    label="Tile Overlap (higher=smoother)")

            enhance = gr.Checkbox(value=True, label="Apply post-processing enhancement")

            submit_btn = gr.Button("🚀 Translate to Optical", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Optical Output")
            output_file = gr.File(label="Download TIFF (full resolution)")
            info_text = gr.Textbox(label="Processing Info")

    submit_btn.click(
        fn=translate_sar,
        inputs=[input_image, num_steps, overlap, enhance],
        outputs=[output_image, output_file, info_text],
    )

    gr.Markdown("""
    ---
    **Tips for best results:**
    - For aerial/satellite SAR: Use steps=1-2 for speed, steps=4-8 for quality
    - For noisy SAR: Apply speckle filtering first (Lee or PPB filter)
    - The model works best with Sentinel-1 style imagery

    **Citation:** Qin et al., "Efficient End-to-End Diffusion Model for One-step SAR-to-Optical Translation", IEEE GRSL 2024
    """)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)