| """E3Diff: SAR-to-Optical Translation - HuggingFace Space."""
|
|
|
| import os
|
| import torch
|
| import numpy as np
|
| from PIL import Image, ImageEnhance
|
| import gradio as gr
|
| import tempfile
|
| import time
|
| from huggingface_hub import hf_hub_download
|
|
|
|
|
| from unet import UNet
|
| from diffusion import GaussianDiffusion
|
|
|
|
|
# The optional `spaces` package supplies the `spaces.GPU` decorator used when
# wiring up `translate_sar`. GPU_AVAILABLE only records whether that package
# could be imported; it does not probe for an actual GPU.
try:
    import spaces
except ImportError:
    spaces = None
    GPU_AVAILABLE = False
else:
    GPU_AVAILABLE = True
|
|
|
|
|
class E3DiffInference:
    """E3Diff inference pipeline for single SAR tiles.

    The model configuration is kept identical to the author's local
    ``inference.py`` so that outputs match it.
    """

    def __init__(self, weights_path=None, device="cuda", num_inference_steps=1):
        """Build the diffusion model, load its weights and switch to eval mode.

        Args:
            weights_path: Local checkpoint path. When ``None`` the checkpoint is
                fetched from the ``Dhenenjay/E3Diff-SAR2Optical`` Hub repo.
            device: Preferred torch device; CPU is used whenever CUDA is not
                available, regardless of this value.
            num_inference_steps: Number of sampling timesteps for the
                validation noise schedule.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.image_size = 256
        self.num_inference_steps = num_inference_steps

        print(f"[E3Diff] Initializing on device: {self.device}")
        print(f"[E3Diff] Inference steps: {num_inference_steps}")

        self.model = self._build_model()
        self._load_weights(weights_path)
        self.model.eval()
        print("[E3Diff] Model ready!")

    def _schedule_opt(self):
        """Return the sampling noise-schedule options as a fresh dict.

        A new dict is built on every call so callers that may mutate it do not
        share state (the original code used separate literals per call site).
        """
        return {
            'schedule': 'linear',
            'n_timestep': self.num_inference_steps,
            'linear_start': 1e-6,
            'linear_end': 1e-2,
            'ddim': 1,  # NOTE(review): presumably enables DDIM sampling in GaussianDiffusion
            'lq_noiselevel': 0,
        }

    def _build_model(self):
        """Build the UNet + GaussianDiffusion model (same config as local inference.py)."""
        unet = UNet(
            in_channel=3,
            out_channel=3,
            norm_groups=16,
            inner_channel=64,
            channel_mults=[1, 2, 4, 8, 16],
            attn_res=[],
            res_blocks=1,
            dropout=0,
            image_size=self.image_size,
            condition_ch=3,
        )

        # The same dict object is passed both as `schedule_opt` and inside
        # `opt`, exactly as the original code did.
        schedule_opt = self._schedule_opt()

        opt = {
            'stage': 2,
            'ddim_steps': self.num_inference_steps,
            'model': {
                'beta_schedule': {
                    'train': {'n_timestep': 1000},
                    'val': schedule_opt,
                }
            },
        }

        model = GaussianDiffusion(
            denoise_fn=unet,
            image_size=self.image_size,
            channels=3,
            loss_type='l1',
            conditional=True,
            schedule_opt=schedule_opt,
            xT_noise_r=0,
            seed=1,
            opt=opt,
        )

        return model.to(self.device)

    def _load_weights(self, weights_path):
        """Load a checkpoint into ``self.model`` (same as local inference.py).

        Loading is non-strict; any missing or unexpected keys are reported
        instead of being silently ignored, so a mismatched checkpoint is
        noticed.
        """
        if weights_path is None:
            weights_path = hf_hub_download(
                repo_id="Dhenenjay/E3Diff-SAR2Optical",
                filename="I700000_E719_gen.pth",
            )

        print(f"[E3Diff] Loading weights from: {weights_path}")
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from trusted sources.
        state_dict = torch.load(weights_path, map_location=self.device, weights_only=False)
        result = self.model.load_state_dict(state_dict, strict=False)

        # nn.Module.load_state_dict returns (missing_keys, unexpected_keys);
        # getattr guards against a subclass override that returns None.
        missing = list(getattr(result, "missing_keys", []) or [])
        unexpected = list(getattr(result, "unexpected_keys", []) or [])
        if missing:
            print(f"[E3Diff] Warning: {len(missing)} missing keys, e.g. {missing[:5]}")
        if unexpected:
            print(f"[E3Diff] Warning: {len(unexpected)} unexpected keys, e.g. {unexpected[:5]}")
        print("[E3Diff] Weights loaded!")

    def preprocess(self, image):
        """Convert a PIL image to a (1, 3, H, W) tensor scaled to [-1, 1] on ``self.device``.

        The image is converted to RGB and resized (Lanczos) to
        ``image_size x image_size`` when needed.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if image.size != (self.image_size, self.image_size):
            image = image.resize((self.image_size, self.image_size), Image.LANCZOS)

        img_np = np.array(image).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1)
        img_tensor = img_tensor * 2.0 - 1.0
        return img_tensor.unsqueeze(0).to(self.device)

    def postprocess(self, tensor):
        """Convert a (1, 3, H, W) tensor in [-1, 1] back to an RGB PIL image."""
        tensor = tensor.squeeze(0).cpu()
        tensor = torch.clamp(tensor, -1, 1)
        tensor = (tensor + 1.0) / 2.0
        img_np = (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        return Image.fromarray(img_np)

    @torch.no_grad()
    def translate(self, sar_image, seed=42):
        """Translate one SAR PIL image to an optical PIL image (same as local inference.py).

        Args:
            sar_image: Input PIL image; resized to ``image_size`` if needed.
            seed: When not ``None``, seeds torch and NumPy global RNGs and is
                forwarded to the sampler.
        """
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        sar_tensor = self.preprocess(sar_image)

        # The schedule is re-applied on every call, as in the original code.
        self.model.set_new_noise_schedule(
            self._schedule_opt(),
            self.device,
            num_train_timesteps=1000,
        )

        output, _ = self.model.super_resolution(sar_tensor, continous=False, seed=seed, img_s1=sar_tensor)
        return self.postprocess(output)
|
|
|
|
|
class HighResProcessor:
    """Translate arbitrarily sized images by blending overlapping model tiles."""

    def __init__(self, device="cuda"):
        self.device = device
        self.model = None      # E3DiffInference instance, built lazily by load_model()
        self.tile_size = 256   # side length of each square tile fed to the model
        self.num_steps = None  # step count the current self.model was built with

    def load_model(self, num_steps=1):
        """(Re)build the E3Diff model for ``num_steps`` inference steps."""
        print(f"Loading E3Diff model with {num_steps} steps...")
        self.model = E3DiffInference(device=self.device, num_inference_steps=num_steps)
        self.num_steps = num_steps

    def create_blend_weights(self, tile_size, overlap):
        """Return a (tile_size, tile_size, 1) feathering mask for tile blending.

        Each edge band of width ``overlap`` ramps linearly toward the border.
        The ramp is half-pixel centred, so it is strictly positive: no pixel,
        including those on the image border, ends up with zero total weight.
        It also sums to 1 with its mirror, so two overlapping tiles still form
        a partition of unity. ``overlap == 0`` yields an all-ones mask.
        """
        weight = np.ones((tile_size, tile_size))
        if overlap > 0:
            ramp = (np.arange(overlap) + 0.5) / overlap
            weight[:overlap, :] *= ramp[:, np.newaxis]
            weight[-overlap:, :] *= ramp[::-1, np.newaxis]
            weight[:, :overlap] *= ramp[np.newaxis, :]
            weight[:, -overlap:] *= ramp[np.newaxis, ::-1]
        return weight[:, :, np.newaxis]

    @staticmethod
    def _padded_size(size, tile_size, overlap):
        """Return the smallest length >= ``size`` covered exactly by tiles at stride ``tile_size - overlap``.

        The result is never smaller than ``tile_size``, so at least one tile
        is always processed (even for inputs no larger than ``overlap``).
        """
        step = tile_size - overlap
        n_steps = max(1, -(-(size - overlap) // step))  # ceil division, at least one step
        return overlap + n_steps * step

    @staticmethod
    def _to_uint8(arr):
        """Map floats in [0, 1] to uint8 with rounding (not truncation) and clipping."""
        return np.clip(np.rint(arr * 255.0), 0, 255).astype(np.uint8)

    def process(self, image, overlap=64, num_steps=1, seed=42):
        """Translate a full-resolution image tile by tile and blend the results.

        Args:
            image: PIL image (converted to RGB), or an ndarray. NOTE(review):
                an ndarray is used as-is and presumably must be float
                (H, W, 3) scaled to [0, 1], matching what the PIL branch builds.
            overlap: Overlap in pixels between neighbouring tiles; must be in
                ``[0, tile_size)``.
            num_steps: Inference steps; the model is rebuilt when this changes.
            seed: Seed forwarded to ``E3DiffInference.translate`` for every tile.

        Returns:
            uint8 ndarray of shape (H, W, 3).

        Raises:
            ValueError: If ``overlap`` is outside ``[0, tile_size)``.
        """
        tile_size = self.tile_size
        if not 0 <= overlap < tile_size:
            raise ValueError(f"overlap must be in [0, {tile_size}), got {overlap}")

        if self.model is None or self.num_steps != num_steps:
            self.load_model(num_steps)

        if isinstance(image, Image.Image):
            if image.mode != 'RGB':
                image = image.convert('RGB')
            img_np = np.array(image).astype(np.float32) / 255.0
        else:
            img_np = image

        h, w = img_np.shape[:2]
        step = tile_size - overlap

        # Pad (by reflection) so the tile grid covers the image exactly.
        h_pad = self._padded_size(h, tile_size, overlap)
        w_pad = self._padded_size(w, tile_size, overlap)
        img_padded = np.pad(img_np, ((0, h_pad - h), (0, w_pad - w), (0, 0)), mode='reflect')

        output = np.zeros((h_pad, w_pad, 3), dtype=np.float32)
        weights = np.zeros((h_pad, w_pad, 1), dtype=np.float32)
        blend_weight = self.create_blend_weights(tile_size, overlap)

        y_positions = range(0, h_pad - tile_size + 1, step)
        x_positions = range(0, w_pad - tile_size + 1, step)
        total_tiles = len(y_positions) * len(x_positions)

        print(f"Processing {total_tiles} tiles at {w}x{h}...")

        tile_idx = 0
        for y in y_positions:
            for x in x_positions:
                tile = img_padded[y:y + tile_size, x:x + tile_size]
                tile_pil = Image.fromarray(self._to_uint8(tile))

                result_pil = self.model.translate(tile_pil, seed=seed)
                result = np.array(result_pil).astype(np.float32) / 255.0

                output[y:y + tile_size, x:x + tile_size] += result * blend_weight
                weights[y:y + tile_size, x:x + tile_size] += blend_weight

                tile_idx += 1
                if tile_idx % 4 == 0 or tile_idx == total_tiles:
                    print(f" Tile {tile_idx}/{total_tiles}")

        # Normalise by accumulated weight (all > 0 thanks to the positive ramp).
        output = output / (weights + 1e-8)
        return self._to_uint8(output[:h, :w])

    def enhance(self, image, contrast=1.1, sharpness=1.15, color=1.1):
        """Apply mild PIL contrast, sharpness and colour enhancement (in that order)."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = ImageEnhance.Contrast(image).enhance(contrast)
        image = ImageEnhance.Sharpness(image).enhance(sharpness)
        image = ImageEnhance.Color(image).enhance(color)
        return image
|
|
|
|
|
|
|
# Module-level HighResProcessor shared across requests. It stays None until
# _translate_sar_impl creates it on the first request.
processor = None
|
|
|
|
|
| def load_sar_image(filepath):
|
| """Load SAR image from various formats."""
|
| try:
|
| import rasterio
|
| with rasterio.open(filepath) as src:
|
| data = src.read(1)
|
| if data.dtype in [np.float32, np.float64]:
|
| valid = data[np.isfinite(data)]
|
| if len(valid) > 0:
|
| p2, p98 = np.percentile(valid, [2, 98])
|
| data = np.clip(data, p2, p98)
|
| data = ((data - p2) / (p98 - p2 + 1e-8) * 255).astype(np.uint8)
|
| elif data.dtype == np.uint16:
|
| p2, p98 = np.percentile(data, [2, 98])
|
| data = np.clip(data, p2, p98)
|
| data = ((data - p2) / (p98 - p2 + 1e-8) * 255).astype(np.uint8)
|
| return Image.fromarray(data).convert('RGB')
|
| except:
|
| pass
|
|
|
| return Image.open(filepath).convert('RGB')
|
|
|
|
|
| def _translate_sar_impl(file, num_steps, overlap, enhance_output):
|
| """Main translation function."""
|
| global processor
|
|
|
| if file is None:
|
| return None, None, "Please upload a SAR image"
|
|
|
| if processor is None:
|
| processor = HighResProcessor()
|
|
|
| print("Processing SAR image...")
|
|
|
| filepath = file.name if hasattr(file, 'name') else file
|
| image = load_sar_image(filepath)
|
|
|
| w, h = image.size
|
| print(f"Input size: {w}x{h}")
|
|
|
| start = time.time()
|
| result = processor.process(image, overlap=int(overlap), num_steps=int(num_steps))
|
| elapsed = time.time() - start
|
|
|
| result_pil = Image.fromarray(result)
|
|
|
| if enhance_output:
|
| result_pil = processor.enhance(result_pil)
|
|
|
| tiff_path = tempfile.mktemp(suffix='.tiff')
|
| result_pil.save(tiff_path, format='TIFF', compression='lzw')
|
|
|
| print(f"Complete in {elapsed:.1f}s!")
|
|
|
| info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"
|
|
|
| return result_pil, tiff_path, info
|
|
|
|
|
|
|
# Wrap the handler with the `spaces.GPU` decorator (300 s duration) when the
# `spaces` package imported successfully; otherwise expose it unchanged.
translate_sar = (
    spaces.GPU(duration=300)(_translate_sar_impl)
    if GPU_AVAILABLE and spaces is not None
    else _translate_sar_impl
)
|
|
|
|
|
|
|
# Gradio UI: file upload and settings on the left, translated image, TIFF
# download and timing info on the right.
with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
    gr.Markdown("""
# 🛰️ E3Diff: High-Resolution SAR-to-Optical Translation

**CVPR PBVS2025 Challenge Winner** | Upload any SAR image and get a photorealistic optical translation.

- Supports full resolution processing with seamless tiling
- Multiple quality levels (1-8 inference steps)
- TIFF output for commercial use
""")

    with gr.Row():
        with gr.Column():
            # Input column.
            input_file = gr.File(label="SAR Input (TIFF, PNG, JPG)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])

            with gr.Row():
                # Slider values arrive as numbers; _translate_sar_impl casts them to int.
                num_steps = gr.Slider(1, 8, value=1, step=1, label="Quality Steps (1=fast, 8=best)")
                overlap = gr.Slider(16, 128, value=64, step=16, label="Tile Overlap")

            enhance = gr.Checkbox(value=True, label="Apply enhancement")
            submit_btn = gr.Button("🚀 Translate to Optical", variant="primary")

        with gr.Column():
            # Output column; order matches the tuple returned by translate_sar.
            output_image = gr.Image(label="Optical Output")
            output_file = gr.File(label="Download TIFF")
            info_text = gr.Textbox(label="Processing Info")

    # Inputs map positionally to (file, num_steps, overlap, enhance_output).
    submit_btn.click(
        fn=translate_sar,
        inputs=[input_file, num_steps, overlap, enhance],
        outputs=[output_image, output_file, info_text]
    )

    gr.Markdown("""
---
**Tips:** Use steps=1 for speed, steps=4-8 for quality. Works best with Sentinel-1 style SAR.
""")
|
|
|
|
|
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server.
    # NOTE(review): ssr_mode=False presumably disables Gradio's server-side
    # rendering; confirm against the installed Gradio version.
    demo.queue().launch(ssr_mode=False)
|
|
|