"""
Axion: SAR-to-Optical Translation - HuggingFace Space.

Fixed for ZeroGPU with lazy loading: torch and the model code are only
imported on first use so the Space starts quickly.
"""
import os
import tempfile
import time

import gradio as gr
import numpy as np
from PIL import Image, ImageEnhance

print("[Axion] Starting app...")

# ZeroGPU support: the optional `spaces` package provides the GPU decorator.
try:
    import spaces
    GPU_AVAILABLE = True
    print("[Axion] ZeroGPU available")
except ImportError:
    GPU_AVAILABLE = False
    spaces = None
    print("[Axion] Running without ZeroGPU")

# Lazy imports for heavy modules (filled in on first call to the getters below).
_torch = None
_model_modules = None


def get_torch():
    """Import torch on first use and return the cached module object."""
    global _torch
    if _torch is None:
        print("[Axion] Importing torch...")
        import torch
        _torch = torch
        print(f"[Axion] PyTorch {torch.__version__} loaded")
    return _torch


def get_model_modules():
    """Import the project's model classes on first use.

    Returns:
        Tuple ``(UNet, GaussianDiffusion)`` from the local ``unet`` and
        ``diffusion`` modules, cached after the first call.
    """
    global _model_modules
    if _model_modules is None:
        print("[Axion] Importing model modules...")
        from unet import UNet
        from diffusion import GaussianDiffusion
        _model_modules = (UNet, GaussianDiffusion)
        print("[Axion] Model modules loaded")
    return _model_modules


def _stretch_to_uint8(data):
    """Map a single-band raster to uint8 with a 2nd-98th percentile linear stretch.

    uint8 arrays, and arrays whose dtype is not a real integer/float (bool,
    complex, ...), are returned unchanged for Image.fromarray to accept or
    reject. Real integer and floating data is clipped to its 2nd/98th
    percentiles, computed over finite samples only, and scaled to 0-255.
    Non-finite samples become 0; an array with no finite samples becomes all 0.
    """
    is_real = np.issubdtype(data.dtype, np.integer) or np.issubdtype(data.dtype, np.floating)
    if data.dtype == np.uint8 or not is_real:
        return data
    values = data.astype(np.float64)
    finite = np.isfinite(values)
    if not finite.any():
        return np.zeros(values.shape, dtype=np.uint8)
    p2, p98 = np.percentile(values[finite], [2, 98])
    scaled = (np.clip(values, p2, p98) - p2) / (p98 - p2 + 1e-8) * 255
    # NaN passes through np.clip unchanged; zero it so the uint8 cast is well defined.
    scaled[~finite] = 0
    return scaled.astype(np.uint8)


def load_sar_image(filepath):
    """Load a SAR image from various formats as an RGB PIL image.

    When rasterio is importable, band 1 is read and contrast-stretched to
    8 bits (see _stretch_to_uint8), then replicated to RGB. If rasterio is
    missing, or reading/converting with it fails for any reason, the file is
    opened with PIL instead (a deliberate best-effort fallback, now logged).

    Args:
        filepath: Path to the image file.

    Returns:
        PIL.Image.Image in mode 'RGB'.
    """
    try:
        import rasterio
    except ImportError:
        print("[Axion] rasterio not installed; loading with PIL")
    else:
        try:
            with rasterio.open(filepath) as src:
                data = src.read(1)
            return Image.fromarray(_stretch_to_uint8(data)).convert('RGB')
        except Exception as exc:  # best-effort: any rasterio/conversion failure -> PIL
            print(f"[Axion] rasterio could not load {filepath} ({exc}); falling back to PIL")
    return Image.open(filepath).convert('RGB')


def create_blend_weights(tile_size, overlap):
    """Create smooth blending weights for seamless tiled output.

    Each of the four borders of a ``tile_size`` x ``tile_size`` tile is
    tapered linearly over ``overlap`` pixels. The ramp is
    ``(i + 1) / (overlap + 1)``, so it is strictly positive. Where two tiles
    overlap, the ascending and descending ramps still sum to 1. Pixels on the
    outer edge of the image, which only one tile covers, never get weight 0,
    so they do not normalise to black. ``overlap == 0`` yields uniform
    weights.

    Returns:
        float ndarray of shape (tile_size, tile_size, 1).
    """
    weight = np.ones((tile_size, tile_size))
    if overlap > 0:
        ramp = np.arange(1, overlap + 1) / (overlap + 1)
        weight[:overlap, :] *= ramp[:, np.newaxis]
        weight[-overlap:, :] *= ramp[::-1, np.newaxis]
        weight[:, :overlap] *= ramp[np.newaxis, :]
        weight[:, -overlap:] *= ramp[np.newaxis, ::-1]
    return weight[:, :, np.newaxis]
# Number of reverse-diffusion steps used at inference (single-step DDIM).
_NUM_INFERENCE_STEPS = 1


def _inference_schedule():
    """Return a fresh dict of the noise-schedule options used at inference.

    build_model() and translate_tile() both need these options; keeping one
    definition stops them drifting apart. A new dict is built on every call,
    so a callee that mutates it cannot affect later calls.
    """
    return {
        'schedule': 'linear',
        'n_timestep': _NUM_INFERENCE_STEPS,
        'linear_start': 1e-6,
        'linear_end': 1e-2,
        'ddim': 1,
        'lq_noiselevel': 0,
    }


def build_model(device):
    """Build the Axion UNet + GaussianDiffusion model and load its weights.

    Weights are downloaded from the ``Dhenenjay/Axion-S2O`` HuggingFace repo.

    Args:
        device: torch device the model and weights are placed on.

    Returns:
        The GaussianDiffusion model in eval mode.
    """
    torch = get_torch()
    UNet, GaussianDiffusion = get_model_modules()
    from huggingface_hub import hf_hub_download

    print("[Axion] Building model architecture...")
    image_size = 256

    # UNet configuration
    unet = UNet(
        in_channel=3,
        out_channel=3,
        norm_groups=16,
        inner_channel=64,
        channel_mults=[1, 2, 4, 8, 16],
        attn_res=[],
        res_blocks=1,
        dropout=0,
        image_size=image_size,
        condition_ch=3,
    )

    # Diffusion wrapper; the same schedule dict is passed directly and nested in opt.
    schedule_opt = _inference_schedule()
    opt = {
        'stage': 2,
        'ddim_steps': _NUM_INFERENCE_STEPS,
        'model': {
            'beta_schedule': {
                'train': {'n_timestep': 1000},
                'val': schedule_opt,
            }
        },
    }
    model = GaussianDiffusion(
        denoise_fn=unet,
        image_size=image_size,
        channels=3,
        loss_type='l1',
        conditional=True,
        schedule_opt=schedule_opt,
        xT_noise_r=0,
        seed=1,
        opt=opt,
    )
    model = model.to(device)

    # Load weights
    print("[Axion] Downloading weights...")
    weights_path = hf_hub_download(
        repo_id="Dhenenjay/Axion-S2O",
        filename="I700000_E719_gen.pth",
    )
    print(f"[Axion] Loading weights from: {weights_path}")
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary objects.
    # That is only acceptable because the checkpoint comes from a fixed, trusted repo;
    # try weights_only=True if the checkpoint is a plain state dict.
    state_dict = torch.load(weights_path, map_location=device, weights_only=False)
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False tolerates key mismatches; report them so a wrong checkpoint is not silent.
    missing = getattr(load_result, 'missing_keys', [])
    unexpected = getattr(load_result, 'unexpected_keys', [])
    if missing or unexpected:
        print(f"[Axion] Warning: checkpoint has {len(missing)} missing and "
              f"{len(unexpected)} unexpected keys")
    model.eval()
    print("[Axion] Model ready!")
    return model


def preprocess(image, device, image_size=256):
    """Convert a PIL image to a (1, 3, image_size, image_size) tensor in [-1, 1].

    The image is converted to RGB and LANCZOS-resized if it is not already
    ``image_size`` x ``image_size``.
    """
    torch = get_torch()
    if image.mode != 'RGB':
        image = image.convert('RGB')
    if image.size != (image_size, image_size):
        image = image.resize((image_size, image_size), Image.LANCZOS)
    img_np = np.array(image).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1)  # HWC -> CHW
    img_tensor = img_tensor * 2.0 - 1.0
    return img_tensor.unsqueeze(0).to(device)


def postprocess(tensor):
    """Convert a (1, C, H, W) model output in [-1, 1] to a uint8 PIL image."""
    torch = get_torch()
    tensor = tensor.squeeze(0).cpu()
    tensor = torch.clamp(tensor, -1, 1)
    tensor = (tensor + 1.0) / 2.0
    img_np = (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return Image.fromarray(img_np)


def translate_tile(model, sar_pil, device, seed=42):
    """Translate a single tile with one diffusion step.

    Args:
        model: Model returned by build_model().
        sar_pil: Input tile as a PIL image (resized to 256 by preprocess()).
        device: torch device.
        seed: If not None, seeds torch and numpy's global RNGs first, and is
            forwarded to super_resolution (which gets 1 when seed is None).

    Returns:
        Translated tile as a PIL image.
    """
    torch = get_torch()
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
    sar_tensor = preprocess(sar_pil, device)
    model.set_new_noise_schedule(_inference_schedule(), device, num_train_timesteps=1000)
    with torch.no_grad():
        output, _ = model.super_resolution(
            sar_tensor,
            continous=False,
            seed=seed if seed is not None else 1,
            img_s1=sar_tensor,
        )
    return postprocess(output)


def enhance_image(image, contrast=1.1, sharpness=1.2, color=1.1):
    """Professional post-processing: boost contrast, sharpness and color saturation."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = ImageEnhance.Contrast(image).enhance(contrast)
    image = ImageEnhance.Sharpness(image).enhance(sharpness)
    image = ImageEnhance.Color(image).enhance(color)
    return image


def _float_to_uint8(arr):
    """Convert floats in [0, 1] to uint8, rounding to nearest and clipping to 0-255."""
    return np.clip(np.rint(arr * 255), 0, 255).astype(np.uint8)


def _padded_extent(size, tile_size, overlap):
    """Smallest length >= size and >= tile_size that tiles exactly with the given overlap.

    The result has the form ``n * (tile_size - overlap) + overlap`` with n >= 1.
    """
    step = tile_size - overlap
    n_tiles = max(1, -(-(size - overlap) // step))  # ceil division, at least one tile
    return n_tiles * step + overlap


def process_image(image, model, device, overlap=64):
    """Process an image at full resolution with seamless 256x256 tiling.

    The image is reflect-padded so that a grid of tiles with stride
    ``256 - overlap`` covers it exactly. Each tile is translated, and the
    results are blended with create_blend_weights() and normalised by the
    accumulated weights.

    Args:
        image: PIL image (converted to RGB if needed), or an ndarray assumed
            to be H x W x 3 floats in [0, 1] (not validated here).
        model: Model passed through to translate_tile().
        device: torch device passed through to translate_tile().
        overlap: Pixels shared by neighbouring tiles; must satisfy 0 <= overlap < 256.

    Returns:
        uint8 ndarray of shape (H, W, 3).

    Raises:
        ValueError: if overlap is outside [0, 256).
    """
    tile_size = 256
    if not 0 <= overlap < tile_size:
        raise ValueError(f"overlap must be in [0, {tile_size}), got {overlap}")
    step = tile_size - overlap

    if isinstance(image, Image.Image):
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_np = np.array(image).astype(np.float32) / 255.0
    else:
        img_np = image
    h, w = img_np.shape[:2]

    # Pad image. _padded_extent never returns less than tile_size, so images whose
    # side is <= overlap still get one tile instead of an empty grid (black output).
    h_pad = _padded_extent(h, tile_size, overlap)
    w_pad = _padded_extent(w, tile_size, overlap)
    img_padded = np.pad(img_np, ((0, h_pad - h), (0, w_pad - w), (0, 0)), mode='reflect')

    # Output accumulators
    output = np.zeros((h_pad, w_pad, 3), dtype=np.float32)
    weights = np.zeros((h_pad, w_pad, 1), dtype=np.float32)
    blend_weight = create_blend_weights(tile_size, overlap)

    # Tile positions
    y_positions = list(range(0, h_pad - tile_size + 1, step))
    x_positions = list(range(0, w_pad - tile_size + 1, step))
    total_tiles = len(y_positions) * len(x_positions)
    print(f"[Axion] Processing {total_tiles} tiles ({len(x_positions)}x{len(y_positions)}) at {w}x{h}...")

    tile_idx = 0
    for y in y_positions:
        for x in x_positions:
            tile = img_padded[y:y + tile_size, x:x + tile_size]
            # Round rather than truncate so uint8 input survives the float round-trip unchanged.
            tile_pil = Image.fromarray(_float_to_uint8(tile))

            result_pil = translate_tile(model, tile_pil, device, seed=42)
            result = np.array(result_pil).astype(np.float32) / 255.0

            output[y:y + tile_size, x:x + tile_size] += result * blend_weight
            weights[y:y + tile_size, x:x + tile_size] += blend_weight

            tile_idx += 1
            if tile_idx % 10 == 0 or tile_idx == total_tiles:
                print(f"[Axion] Tile {tile_idx}/{total_tiles}")

    # Normalise by accumulated weights and drop the padding.
    output = output / (weights + 1e-8)
    return _float_to_uint8(output[:h, :w])


# Global model cache: built on the first request, reused afterwards.
_cached_model = None


def _translate_impl(file, overlap, enhance_output):
    """Main translation function - runs on GPU when available.

    Args:
        file: Upload from gr.File, either a path string or an object with a ``.name`` path.
        overlap: Tile overlap in pixels (converted with int()).
        enhance_output: If truthy, apply enhance_image() to the result.

    Returns:
        ``(PIL image, path to a TIFF copy, info string)``, or
        ``(None, None, message)`` when no file was uploaded.
    """
    global _cached_model
    if file is None:
        return None, None, "Please upload a SAR image"

    torch = get_torch()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Axion] Using device: {device}")

    # Load model (cached)
    if _cached_model is None:
        _cached_model = build_model(device)
    model = _cached_model

    # Load image
    filepath = file.name if hasattr(file, 'name') else file
    print(f"[Axion] Loading: {filepath}")
    image = load_sar_image(filepath)
    w, h = image.size
    print(f"[Axion] Input size: {w}x{h}")

    start = time.perf_counter()
    result = process_image(image, model, device, overlap=int(overlap))
    elapsed = time.perf_counter() - start

    result_pil = Image.fromarray(result)
    if enhance_output:
        result_pil = enhance_image(result_pil)

    # Create the file atomically instead of using the deprecated, race-prone tempfile.mktemp().
    with tempfile.NamedTemporaryFile(suffix='.tiff', delete=False) as tmp:
        tiff_path = tmp.name
    # Pillow documents LZW compression under the name 'tiff_lzw'.
    result_pil.save(tiff_path, format='TIFF', compression='tiff_lzw')

    print(f"[Axion] Complete in {elapsed:.1f}s!")
    info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"
    return result_pil, tiff_path, info


# Apply GPU decorator
if GPU_AVAILABLE and spaces is not None:
    @spaces.GPU(duration=300)
    def translate_sar(file, overlap, enhance_output):
        """ZeroGPU entry point; delegates to _translate_impl."""
        return _translate_impl(file, overlap, enhance_output)
else:
    translate_sar = _translate_impl

print("[Axion] Building Gradio interface...")

# Create Gradio interface
with gr.Blocks(title="Axion - SAR to Optical") as demo:
    gr.HTML("""

SAR to Optical Image Translation

Transform radar imagery into crystal-clear optical views using our foundation model

""")
    with gr.Row():
        with gr.Column():
            input_file = gr.File(label="Upload SAR Image", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
            gr.HTML("""
Input Guidelines:
• Use raw SAR imagery (single-band grayscale)
• VV polarization preferred, VH also supported
• Any resolution supported (processed in 256×256 tiles)
""")
            with gr.Row():
                overlap = gr.Slider(16, 128, value=64, step=16, label="Tile Overlap")
                enhance = gr.Checkbox(value=True, label="Enhance Output")
            submit_btn = gr.Button("Translate", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Optical Output")
            output_file = gr.File(label="Download")
            info_text = gr.Textbox(label="Info", show_label=False)

    submit_btn.click(
        fn=translate_sar,
        inputs=[input_file, overlap, enhance],
        outputs=[output_image, output_file, info_text],
    )
    gr.HTML("""
Powered by Axion
""")

print("[Axion] Launching app...")

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)