Dhenenjay committed on
Commit
98d98f5
·
verified ·
1 Parent(s): 7f03507

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +284 -211
app.py CHANGED
@@ -4,9 +4,8 @@ HuggingFace Spaces Deployment
4
 
5
  Features:
6
  - Full resolution processing with seamless tiling
7
- - Multi-step inference for maximum quality
8
  - TIFF output support
9
- - Professional post-processing
10
  """
11
 
12
  import os
@@ -20,7 +19,7 @@ import gradio as gr
20
  from pathlib import Path
21
  import tempfile
22
  import time
23
- from tqdm import tqdm
24
  from huggingface_hub import hf_hub_download
25
 
26
  # ============================================================================
@@ -59,8 +58,7 @@ class SoftPool2d(nn.Module):
59
  return soft_pool2d(x, self.kernel_size, self.stride)
60
 
61
 
62
- # Monkey-patch SoftPool into the expected location
63
- import sys
64
  class SoftPoolModule:
65
  soft_pool2d = staticmethod(soft_pool2d)
66
  SoftPool2d = SoftPool2d
@@ -198,82 +196,66 @@ class ResnetBlocWithAttn(nn.Module):
198
  if with_attn:
199
  self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
200
 
201
- def forward(self, x, time_emb, c, t=0, save_flag=False, file_i=0):
202
  x = self.res_block(x, time_emb, c)
203
  if self.with_attn:
204
- x = self.attn(x, t=t, save_flag=save_flag, file_num=file_i)
205
  return x
206
 
207
 
208
- class ResBlock_normal(nn.Module):
209
- def __init__(self, dim, dim_out, dropout=0, norm_groups=32):
210
- super().__init__()
211
- self.block1 = Block(dim, dim_out, groups=norm_groups)
212
- self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
213
- self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
214
-
215
- def forward(self, x):
216
- h = self.block1(x)
217
- h = self.block2(h)
218
- return h + self.res_conv(x)
219
-
220
-
221
  class CPEN(nn.Module):
222
- def __init__(self, inchannel=1):
223
  super(CPEN, self).__init__()
224
- self.pool = SoftPool2d(kernel_size=(2,2), stride=(2,2))
225
- self.E1 = nn.Sequential(nn.Conv2d(inchannel, 64, kernel_size=3, padding=1), Swish())
226
- self.E2 = nn.Sequential(ResBlock_normal(64, 128, dropout=0, norm_groups=16), ResBlock_normal(128, 128, dropout=0, norm_groups=16))
227
- self.E3 = nn.Sequential(ResBlock_normal(128, 256, dropout=0, norm_groups=16), ResBlock_normal(256, 256, dropout=0, norm_groups=16))
228
- self.E4 = nn.Sequential(ResBlock_normal(256, 512, dropout=0, norm_groups=16), ResBlock_normal(512, 512, dropout=0, norm_groups=16))
229
- self.E5 = nn.Sequential(ResBlock_normal(512, 512, dropout=0, norm_groups=16), ResBlock_normal(512, 1024, dropout=0, norm_groups=16))
 
 
 
 
 
230
 
231
  def forward(self, x):
232
- x1 = self.E1(x)
233
- x2 = self.pool(x1)
234
- x2 = self.E2(x2)
235
- x3 = self.pool(x2)
236
- x3 = self.E3(x3)
237
- x4 = self.pool(x3)
238
- x4 = self.E4(x4)
239
- x5 = self.pool(x4)
240
- x5 = self.E5(x5)
241
- return x1, x2, x3, x4, x5
242
 
243
 
244
  class UNet(nn.Module):
245
  def __init__(self, in_channel=6, out_channel=3, inner_channel=32, norm_groups=32,
246
- channel_mults=(1, 2, 4, 8, 8), attn_res=(8), res_blocks=3, dropout=0,
247
  with_noise_level_emb=True, image_size=128, condition_ch=3):
248
  super().__init__()
249
 
250
- if with_noise_level_emb:
251
- noise_level_channel = inner_channel
252
- self.noise_level_mlp = nn.Sequential(
253
- PositionalEncoding(inner_channel),
254
- nn.Linear(inner_channel, inner_channel * 4),
255
- Swish(),
256
- nn.Linear(inner_channel * 4, inner_channel)
257
- )
258
- else:
259
- noise_level_channel = None
260
- self.noise_level_mlp = None
261
-
262
  self.res_blocks = res_blocks
 
 
 
 
 
 
 
 
263
  num_mults = len(channel_mults)
264
- self.num_mults = num_mults
265
  pre_channel = inner_channel
266
  feat_channels = [pre_channel]
267
  now_res = image_size
268
-
269
  downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)]
270
  for ind in range(num_mults):
271
  is_last = (ind == num_mults - 1)
272
  use_attn = (now_res in attn_res)
273
  channel_mult = inner_channel * channel_mults[ind]
274
  for _ in range(0, res_blocks):
275
- downs.append(ResnetBlocWithAttn(pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel,
276
- norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, size=now_res))
277
  feat_channels.append(channel_mult)
278
  pre_channel = channel_mult
279
  if not is_last:
@@ -283,7 +265,7 @@ class UNet(nn.Module):
283
  self.downs = nn.ModuleList(downs)
284
 
285
  self.mid = nn.ModuleList([
286
- ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
287
  norm_groups=norm_groups, dropout=dropout, with_attn=True, size=now_res),
288
  ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
289
  norm_groups=norm_groups, dropout=dropout, with_attn=False, size=now_res)
@@ -359,25 +341,135 @@ class UNet(nn.Module):
359
 
360
 
361
  # ============================================================================
362
- # E3Diff High-Resolution Inference
363
  # ============================================================================
364
 
365
- class E3DiffHighRes:
366
- def __init__(self, device="cuda"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  self.device = torch.device(device if torch.cuda.is_available() else "cpu")
368
- self.model = None
369
  self.image_size = 256
 
370
 
371
- def load_model(self, weights_path=None):
372
- if weights_path is None:
373
- # Download from HuggingFace
374
- weights_path = hf_hub_download(
375
- repo_id="Dhenenjay/E3Diff-SAR2Optical",
376
- filename="I700000_E719_gen.pth"
377
- )
378
 
379
- # Build UNet
380
- self.model = UNet(
 
 
 
 
 
381
  in_channel=3,
382
  out_channel=3,
383
  norm_groups=16,
@@ -388,88 +480,115 @@ class E3DiffHighRes:
388
  dropout=0,
389
  image_size=self.image_size,
390
  condition_ch=3
391
- ).to(self.device)
392
-
393
- # Load weights
394
- state_dict = torch.load(weights_path, map_location=self.device, weights_only=False)
395
 
396
- # Filter only UNet weights
397
- unet_dict = {k.replace('denoise_fn.', ''): v for k, v in state_dict.items()
398
- if k.startswith('denoise_fn.')}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
 
400
- self.model.load_state_dict(unet_dict, strict=False)
401
- self.model.eval()
402
- print(f"Model loaded on {self.device}")
 
 
 
 
 
403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  @torch.no_grad()
405
- def translate_tile(self, tile_tensor, num_steps=1):
406
- """Translate a single 256x256 tile."""
407
- batch_size = tile_tensor.shape[0]
408
-
409
- # Initialize noise
410
- noise = torch.randn(batch_size, 3, self.image_size, self.image_size, device=self.device)
411
-
412
- # DDIM sampling
413
- total_timesteps = 1000
414
- ts = torch.linspace(total_timesteps, 0, num_steps + 1).to(self.device).long()
415
-
416
- # Create beta schedule
417
- betas = torch.linspace(1e-6, 1e-2, total_timesteps, device=self.device)
418
- alphas = 1. - betas
419
- alphas_cumprod = torch.cumprod(alphas, dim=0)
420
- sqrt_alphas_cumprod_prev = torch.sqrt(torch.cat([torch.ones(1, device=self.device), alphas_cumprod]))
421
-
422
- x = noise
423
- for i in range(1, num_steps + 1):
424
- cur_t = ts[i - 1] - 1
425
- prev_t = ts[i] - 1
426
-
427
- noise_level = sqrt_alphas_cumprod_prev[cur_t].repeat(batch_size, 1)
428
-
429
- alpha_prod_t = alphas_cumprod[cur_t]
430
- alpha_prod_t_prev = alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0, device=self.device)
431
- beta_prod_t = 1 - alpha_prod_t
432
-
433
- # Model prediction
434
- model_input = torch.cat([tile_tensor, x], dim=1)
435
- model_output = self.model(model_input, noise_level)
436
-
437
- # DDIM update
438
- pred_original = (x - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
439
- pred_original = pred_original.clamp(-1, 1)
440
-
441
- sigma_2 = 0.8 * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
442
- pred_dir = (1 - alpha_prod_t_prev - sigma_2) ** 0.5 * model_output
443
-
444
- if i < num_steps:
445
- noise = torch.randn_like(x)
446
- x = alpha_prod_t_prev ** 0.5 * pred_original + pred_dir + sigma_2 ** 0.5 * noise
447
- else:
448
- x = pred_original
449
 
450
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
 
452
  def create_blend_weights(self, tile_size, overlap):
453
- """Create smooth blending weights for seamless tiling."""
454
- # Linear ramp for overlap regions
455
  ramp = np.linspace(0, 1, overlap)
456
-
457
- # Create 2D weight matrix
458
  weight = np.ones((tile_size, tile_size))
459
-
460
- # Apply ramps to edges
461
- weight[:overlap, :] *= ramp[:, np.newaxis] # Top
462
- weight[-overlap:, :] *= ramp[::-1, np.newaxis] # Bottom
463
- weight[:, :overlap] *= ramp[np.newaxis, :] # Left
464
- weight[:, -overlap:] *= ramp[np.newaxis, ::-1] # Right
465
-
466
  return weight[:, :, np.newaxis]
467
 
468
- def translate_full_resolution(self, image, num_steps=1, overlap=64, progress_callback=None):
469
- """
470
- Translate full resolution image using seamless tiling.
471
- """
472
- # Convert to numpy if PIL
473
  if isinstance(image, Image.Image):
474
  if image.mode != 'RGB':
475
  image = image.convert('RGB')
@@ -478,77 +597,51 @@ class E3DiffHighRes:
478
  img_np = image
479
 
480
  h, w = img_np.shape[:2]
481
- tile_size = self.image_size
482
  step = tile_size - overlap
483
 
484
- # Pad image to ensure full coverage
485
  pad_h = (step - (h - overlap) % step) % step
486
  pad_w = (step - (w - overlap) % step) % step
487
  img_padded = np.pad(img_np, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
488
 
489
  h_pad, w_pad = img_padded.shape[:2]
490
 
491
- # Output arrays
492
  output = np.zeros((h_pad, w_pad, 3), dtype=np.float32)
493
  weights = np.zeros((h_pad, w_pad, 1), dtype=np.float32)
494
-
495
- # Blending weights
496
  blend_weight = self.create_blend_weights(tile_size, overlap)
497
 
498
- # Calculate tile positions
499
  y_positions = list(range(0, h_pad - tile_size + 1, step))
500
  x_positions = list(range(0, w_pad - tile_size + 1, step))
501
  total_tiles = len(y_positions) * len(x_positions)
502
 
503
- print(f"Processing {total_tiles} tiles ({len(x_positions)}x{len(y_positions)})...")
504
 
505
  tile_idx = 0
506
  for y in y_positions:
507
  for x in x_positions:
508
- # Extract tile
509
  tile = img_padded[y:y+tile_size, x:x+tile_size]
 
510
 
511
- # Convert to tensor [-1, 1]
512
- tile_tensor = torch.from_numpy(tile).permute(2, 0, 1).unsqueeze(0)
513
- tile_tensor = tile_tensor * 2.0 - 1.0
514
- tile_tensor = tile_tensor.to(self.device)
515
-
516
- # Translate
517
- result_tensor = self.translate_tile(tile_tensor, num_steps)
518
 
519
- # Convert back to numpy [0, 1]
520
- result = result_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
521
- result = (result + 1.0) / 2.0
522
- result = np.clip(result, 0, 1)
523
-
524
- # Add to output with blending
525
  output[y:y+tile_size, x:x+tile_size] += result * blend_weight
526
  weights[y:y+tile_size, x:x+tile_size] += blend_weight
527
 
528
  tile_idx += 1
529
- if progress_callback:
530
- progress_callback(tile_idx / total_tiles)
531
 
532
- # Normalize by weights
533
  output = output / (weights + 1e-8)
534
-
535
- # Crop to original size
536
  output = output[:h, :w]
537
 
538
- return output
539
 
540
- def enhance_output(self, image, contrast=1.1, sharpness=1.15, color=1.1):
541
- """Apply professional post-processing."""
542
  if isinstance(image, np.ndarray):
543
- image = Image.fromarray((image * 255).astype(np.uint8))
544
-
545
- # Contrast
546
  image = ImageEnhance.Contrast(image).enhance(contrast)
547
- # Sharpness
548
  image = ImageEnhance.Sharpness(image).enhance(sharpness)
549
- # Color saturation
550
  image = ImageEnhance.Color(image).enhance(color)
551
-
552
  return image
553
 
554
 
@@ -556,7 +649,7 @@ class E3DiffHighRes:
556
  # Gradio Interface
557
  # ============================================================================
558
 
559
- model = None
560
 
561
  def load_sar_image(filepath):
562
  """Load SAR image from various formats."""
@@ -581,58 +674,44 @@ def load_sar_image(filepath):
581
  return Image.open(filepath).convert('RGB')
582
 
583
 
584
- def translate_sar(file, num_steps, overlap, enhance):
585
  """Main translation function."""
586
- global model
587
 
588
  if file is None:
589
  return None, None, "Please upload a SAR image"
590
 
591
- if model is None:
592
- print("Loading model...")
593
- model = E3DiffHighRes()
594
- model.load_model()
595
 
596
- print("Processing image...")
597
 
598
- # Handle file upload - get the filepath
599
  filepath = file.name if hasattr(file, 'name') else file
600
  image = load_sar_image(filepath)
601
 
602
  w, h = image.size
603
  print(f"Input size: {w}x{h}")
604
 
605
- # Translate
606
  start = time.time()
607
- result = model.translate_full_resolution(
608
- image,
609
- num_steps=num_steps,
610
- overlap=overlap,
611
- progress_callback=None
612
- )
613
  elapsed = time.time() - start
614
 
615
- print("Post-processing...")
616
-
617
- # Convert to PIL
618
- result_pil = Image.fromarray((result * 255).astype(np.uint8))
619
 
620
- # Enhance if requested
621
- if enhance:
622
- result_pil = model.enhance_output(result_pil)
623
 
624
- # Save as TIFF
625
  tiff_path = tempfile.mktemp(suffix='.tiff')
626
  result_pil.save(tiff_path, format='TIFF', compression='lzw')
627
 
628
- print("Complete!")
629
 
630
  info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"
631
 
632
  return result_pil, tiff_path, info
633
 
634
 
635
- # Create Gradio interface
636
  with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
637
  gr.Markdown("""
638
  # 🛰️ E3Diff: High-Resolution SAR-to-Optical Translation
@@ -641,7 +720,6 @@ with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
641
 
642
  - Supports full resolution processing with seamless tiling
643
  - Multiple quality levels (1-8 inference steps)
644
- - Professional post-processing
645
  - TIFF output for commercial use
646
  """)
647
 
@@ -650,16 +728,16 @@ with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
650
  input_file = gr.File(label="SAR Input (TIFF, PNG, JPG supported)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
651
 
652
  with gr.Row():
653
- num_steps = gr.Slider(1, 8, value=1, step=1, label="Quality Steps (1=fast, 4-8=high quality)")
654
- overlap = gr.Slider(16, 128, value=64, step=16, label="Tile Overlap (higher=smoother)")
655
 
656
- enhance = gr.Checkbox(value=True, label="Apply post-processing enhancement")
657
 
658
  submit_btn = gr.Button("🚀 Translate to Optical", variant="primary")
659
 
660
  with gr.Column():
661
  output_image = gr.Image(label="Optical Output")
662
- output_file = gr.File(label="Download TIFF (full resolution)")
663
  info_text = gr.Textbox(label="Processing Info")
664
 
665
  submit_btn.click(
@@ -670,12 +748,7 @@ with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
670
 
671
  gr.Markdown("""
672
  ---
673
- **Tips for best results:**
674
- - For aerial/satellite SAR: Use steps=1-2 for speed, steps=4-8 for quality
675
- - For noisy SAR: Apply speckle filtering first (Lee or PPB filter)
676
- - The model works best with Sentinel-1 style imagery
677
-
678
- **Citation:** Qin et al., "Efficient End-to-End Diffusion Model for One-step SAR-to-Optical Translation", IEEE GRSL 2024
679
  """)
680
 
681
 
 
4
 
5
  Features:
6
  - Full resolution processing with seamless tiling
7
+ - Proper diffusion sampling (matching local inference)
8
  - TIFF output support
 
9
  """
10
 
11
  import os
 
19
  from pathlib import Path
20
  import tempfile
21
  import time
22
+ from functools import partial
23
  from huggingface_hub import hf_hub_download
24
 
25
  # ============================================================================
 
58
  return soft_pool2d(x, self.kernel_size, self.stride)
59
 
60
 
61
+ # Monkey-patch SoftPool
 
62
  class SoftPoolModule:
63
  soft_pool2d = staticmethod(soft_pool2d)
64
  SoftPool2d = SoftPool2d
 
196
  if with_attn:
197
  self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
198
 
199
+ def forward(self, x, time_emb, c):
200
  x = self.res_block(x, time_emb, c)
201
  if self.with_attn:
202
+ x = self.attn(x, time_emb)
203
  return x
204
 
205
 
206
+ # CPEN Condition Encoder
 
 
 
 
 
 
 
 
 
 
 
 
207
  class CPEN(nn.Module):
208
+ def __init__(self, inchannel=3):
209
  super(CPEN, self).__init__()
210
+ from SoftPool import SoftPool2d
211
+
212
+ self.conv1 = nn.Conv2d(inchannel, 64, 3, 1, 1)
213
+ self.pool1 = SoftPool2d(kernel_size=(2, 2), stride=(2, 2))
214
+ self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
215
+ self.pool2 = SoftPool2d(kernel_size=(2, 2), stride=(2, 2))
216
+ self.conv3 = nn.Conv2d(128, 256, 3, 1, 1)
217
+ self.pool3 = SoftPool2d(kernel_size=(2, 2), stride=(2, 2))
218
+ self.conv4 = nn.Conv2d(256, 512, 3, 1, 1)
219
+ self.pool4 = SoftPool2d(kernel_size=(2, 2), stride=(2, 2))
220
+ self.conv5 = nn.Conv2d(512, 1024, 3, 1, 1)
221
 
222
  def forward(self, x):
223
+ c1 = self.pool1(F.leaky_relu(self.conv1(x)))
224
+ c2 = self.pool2(F.leaky_relu(self.conv2(c1)))
225
+ c3 = self.pool3(F.leaky_relu(self.conv3(c2)))
226
+ c4 = self.pool4(F.leaky_relu(self.conv4(c3)))
227
+ c5 = F.leaky_relu(self.conv5(c4))
228
+ return c1, c2, c3, c4, c5
 
 
 
 
229
 
230
 
231
  class UNet(nn.Module):
232
  def __init__(self, in_channel=6, out_channel=3, inner_channel=32, norm_groups=32,
233
+ channel_mults=(1, 2, 4, 8, 8), attn_res=(8,), res_blocks=3, dropout=0,
234
  with_noise_level_emb=True, image_size=128, condition_ch=3):
235
  super().__init__()
236
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  self.res_blocks = res_blocks
238
+ noise_level_channel = inner_channel
239
+ self.noise_level_mlp = nn.Sequential(
240
+ PositionalEncoding(inner_channel),
241
+ nn.Linear(inner_channel, inner_channel * 4),
242
+ Swish(),
243
+ nn.Linear(inner_channel * 4, inner_channel)
244
+ ) if with_noise_level_emb else None
245
+
246
  num_mults = len(channel_mults)
 
247
  pre_channel = inner_channel
248
  feat_channels = [pre_channel]
249
  now_res = image_size
250
+
251
  downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)]
252
  for ind in range(num_mults):
253
  is_last = (ind == num_mults - 1)
254
  use_attn = (now_res in attn_res)
255
  channel_mult = inner_channel * channel_mults[ind]
256
  for _ in range(0, res_blocks):
257
+ downs.append(ResnetBlocWithAttn(pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel,
258
+ norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, size=now_res))
259
  feat_channels.append(channel_mult)
260
  pre_channel = channel_mult
261
  if not is_last:
 
265
  self.downs = nn.ModuleList(downs)
266
 
267
  self.mid = nn.ModuleList([
268
+ ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
269
  norm_groups=norm_groups, dropout=dropout, with_attn=True, size=now_res),
270
  ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
271
  norm_groups=norm_groups, dropout=dropout, with_attn=False, size=now_res)
 
341
 
342
 
343
  # ============================================================================
344
+ # GaussianDiffusion - Proper DDIM Sampling
345
  # ============================================================================
346
 
347
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2):
348
+ if schedule == 'linear':
349
+ betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)
350
+ else:
351
+ raise NotImplementedError(schedule)
352
+ return betas
353
+
354
+
355
+ class GaussianDiffusion(nn.Module):
356
+ def __init__(self, denoise_fn, image_size, channels=3, schedule_opt=None, opt=None):
357
+ super().__init__()
358
+ self.channels = channels
359
+ self.image_size = image_size
360
+ self.denoise_fn = denoise_fn
361
+ self.opt = opt
362
+ self.ddim = schedule_opt.get('ddim', 1) if schedule_opt else 1
363
+
364
+ def set_new_noise_schedule(self, schedule_opt, device, num_train_timesteps=1000):
365
+ self.ddim = schedule_opt['ddim']
366
+ self.num_train_timesteps = num_train_timesteps
367
+ to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
368
+
369
+ betas = make_beta_schedule(
370
+ schedule=schedule_opt['schedule'],
371
+ n_timestep=num_train_timesteps,
372
+ linear_start=schedule_opt['linear_start'],
373
+ linear_end=schedule_opt['linear_end']
374
+ )
375
+
376
+ alphas = 1. - betas
377
+ alphas_cumprod = np.cumprod(alphas, axis=0)
378
+ self.sqrt_alphas_cumprod_prev = np.sqrt(np.append(1., alphas_cumprod))
379
+
380
+ self.num_timesteps = int(betas.shape[0])
381
+ self.register_buffer('betas', to_torch(betas))
382
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
383
+
384
+ self.ddim_num_steps = schedule_opt['n_timestep']
385
+ print(f'DDIM sampling steps: {self.ddim_num_steps}')
386
+
387
+ def ddim_sample(self, condition_x, img_or_shape, device, seed=1):
388
+ """DDIM sampling - matches the original E3Diff implementation."""
389
+ eta = 0.8 # ddim_sampling_eta for linear schedule
390
+
391
+ batch = img_or_shape[0]
392
+ total_timesteps = self.num_train_timesteps
393
+ sampling_timesteps = self.ddim_num_steps
394
+
395
+ ts = torch.linspace(total_timesteps, 0, sampling_timesteps + 1).to(device).long()
396
+ x = torch.randn(img_or_shape, device=device)
397
+ batch_size = x.shape[0]
398
+
399
+ imgs = [x]
400
+ img_onestep = [condition_x[:, :self.channels, ...]]
401
+
402
+ for i in range(1, sampling_timesteps + 1):
403
+ cur_t = ts[i - 1] - 1
404
+ prev_t = ts[i] - 1
405
+
406
+ noise_level = torch.FloatTensor(
407
+ [self.sqrt_alphas_cumprod_prev[cur_t.item()]]
408
+ ).repeat(batch_size, 1).to(device)
409
+
410
+ alpha_prod_t = self.alphas_cumprod[cur_t]
411
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0, device=device)
412
+ beta_prod_t = 1 - alpha_prod_t
413
+
414
+ # Model prediction
415
+ model_output = self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level)
416
+
417
+ # Compute sigma
418
+ sigma_2 = eta * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
419
+ noise = torch.randn_like(x)
420
+
421
+ # Predict original sample
422
+ pred_original_sample = (x - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
423
+ pred_original_sample = pred_original_sample.clamp(-1, 1)
424
+
425
+ pred_sample_direction = (1 - alpha_prod_t_prev - sigma_2) ** 0.5 * model_output
426
+
427
+ x = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction + sigma_2 ** 0.5 * noise
428
+
429
+ imgs.append(x)
430
+ img_onestep.append(pred_original_sample)
431
+
432
+ imgs = torch.cat(imgs, dim=0)
433
+ img_onestep = torch.cat(img_onestep, dim=0)
434
+
435
+ return imgs, img_onestep
436
+
437
+ @torch.no_grad()
438
+ def super_resolution(self, x_in, continous=False, seed=1, img_s1=None):
439
+ """Main inference method."""
440
+ device = self.betas.device
441
+ x = x_in
442
+ shape = (x.shape[0], self.channels, x.shape[-2], x.shape[-1])
443
+
444
+ self.ddim_num_steps = self.opt['ddim_steps']
445
+ ret_img, img_onestep = self.ddim_sample(condition_x=x, img_or_shape=shape, device=device, seed=seed)
446
+
447
+ if continous:
448
+ return ret_img, img_onestep
449
+ else:
450
+ return ret_img[-x_in.shape[0]:], img_onestep
451
+
452
+
453
+ # ============================================================================
454
+ # E3Diff Inference Class
455
+ # ============================================================================
456
+
457
+ class E3DiffInference:
458
+ def __init__(self, weights_path=None, device="cuda", num_inference_steps=1):
459
  self.device = torch.device(device if torch.cuda.is_available() else "cpu")
 
460
  self.image_size = 256
461
+ self.num_inference_steps = num_inference_steps
462
 
463
+ print(f"[E3Diff] Initializing on device: {self.device}")
464
+ print(f"[E3Diff] Inference steps: {num_inference_steps}")
 
 
 
 
 
465
 
466
+ self.model = self._build_model()
467
+ self._load_weights(weights_path)
468
+ self.model.eval()
469
+ print("[E3Diff] Model ready!")
470
+
471
+ def _build_model(self):
472
+ unet = UNet(
473
  in_channel=3,
474
  out_channel=3,
475
  norm_groups=16,
 
480
  dropout=0,
481
  image_size=self.image_size,
482
  condition_ch=3
483
+ )
 
 
 
484
 
485
+ schedule_opt = {
486
+ 'schedule': 'linear',
487
+ 'n_timestep': self.num_inference_steps,
488
+ 'linear_start': 1e-6,
489
+ 'linear_end': 1e-2,
490
+ 'ddim': 1,
491
+ 'lq_noiselevel': 0
492
+ }
493
+
494
+ opt = {
495
+ 'stage': 2,
496
+ 'ddim_steps': self.num_inference_steps,
497
+ }
498
+
499
+ model = GaussianDiffusion(
500
+ denoise_fn=unet,
501
+ image_size=self.image_size,
502
+ channels=3,
503
+ schedule_opt=schedule_opt,
504
+ opt=opt
505
+ )
506
 
507
+ return model.to(self.device)
508
+
509
+ def _load_weights(self, weights_path):
510
+ if weights_path is None:
511
+ weights_path = hf_hub_download(
512
+ repo_id="Dhenenjay/E3Diff-SAR2Optical",
513
+ filename="I700000_E719_gen.pth"
514
+ )
515
 
516
+ print(f"[E3Diff] Loading weights from: {weights_path}")
517
+ state_dict = torch.load(weights_path, map_location=self.device, weights_only=False)
518
+ self.model.load_state_dict(state_dict, strict=False)
519
+ print("[E3Diff] Weights loaded!")
520
+
521
+ def preprocess(self, image):
522
+ if image.mode != 'RGB':
523
+ image = image.convert('RGB')
524
+ if image.size != (self.image_size, self.image_size):
525
+ image = image.resize((self.image_size, self.image_size), Image.LANCZOS)
526
+
527
+ img_np = np.array(image).astype(np.float32) / 255.0
528
+ img_tensor = torch.from_numpy(img_np).permute(2, 0, 1)
529
+ img_tensor = img_tensor * 2.0 - 1.0
530
+ return img_tensor.unsqueeze(0).to(self.device)
531
+
532
+ def postprocess(self, tensor):
533
+ tensor = tensor.squeeze(0).cpu()
534
+ tensor = torch.clamp(tensor, -1, 1)
535
+ tensor = (tensor + 1.0) / 2.0
536
+ img_np = (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
537
+ return Image.fromarray(img_np)
538
+
539
  @torch.no_grad()
540
+ def translate(self, sar_image, seed=42):
541
+ if seed is not None:
542
+ torch.manual_seed(seed)
543
+ np.random.seed(seed)
544
+
545
+ sar_tensor = self.preprocess(sar_image)
546
+
547
+ self.model.set_new_noise_schedule(
548
+ {
549
+ 'schedule': 'linear',
550
+ 'n_timestep': self.num_inference_steps,
551
+ 'linear_start': 1e-6,
552
+ 'linear_end': 1e-2,
553
+ 'ddim': 1,
554
+ 'lq_noiselevel': 0
555
+ },
556
+ self.device,
557
+ num_train_timesteps=1000
558
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
559
 
560
+ output, _ = self.model.super_resolution(sar_tensor, continous=False, seed=seed, img_s1=sar_tensor)
561
+ return self.postprocess(output)
562
+
563
+
564
+ # ============================================================================
565
+ # High-Resolution Processor
566
+ # ============================================================================
567
+
568
+ class HighResProcessor:
569
+ def __init__(self, device="cuda"):
570
+ self.device = device
571
+ self.model = None
572
+ self.tile_size = 256
573
+
574
+ def load_model(self, num_steps=1):
575
+ print("Loading E3Diff model...")
576
+ self.model = E3DiffInference(device=self.device, num_inference_steps=num_steps)
577
+ self.num_steps = num_steps
578
 
579
  def create_blend_weights(self, tile_size, overlap):
 
 
580
  ramp = np.linspace(0, 1, overlap)
 
 
581
  weight = np.ones((tile_size, tile_size))
582
+ weight[:overlap, :] *= ramp[:, np.newaxis]
583
+ weight[-overlap:, :] *= ramp[::-1, np.newaxis]
584
+ weight[:, :overlap] *= ramp[np.newaxis, :]
585
+ weight[:, -overlap:] *= ramp[np.newaxis, ::-1]
 
 
 
586
  return weight[:, :, np.newaxis]
587
 
588
+ def process(self, image, overlap=64, num_steps=1):
589
+ if self.model is None or self.num_steps != num_steps:
590
+ self.load_model(num_steps)
591
+
 
592
  if isinstance(image, Image.Image):
593
  if image.mode != 'RGB':
594
  image = image.convert('RGB')
 
597
  img_np = image
598
 
599
  h, w = img_np.shape[:2]
600
+ tile_size = self.tile_size
601
  step = tile_size - overlap
602
 
 
603
  pad_h = (step - (h - overlap) % step) % step
604
  pad_w = (step - (w - overlap) % step) % step
605
  img_padded = np.pad(img_np, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
606
 
607
  h_pad, w_pad = img_padded.shape[:2]
608
 
 
609
  output = np.zeros((h_pad, w_pad, 3), dtype=np.float32)
610
  weights = np.zeros((h_pad, w_pad, 1), dtype=np.float32)
 
 
611
  blend_weight = self.create_blend_weights(tile_size, overlap)
612
 
 
613
  y_positions = list(range(0, h_pad - tile_size + 1, step))
614
  x_positions = list(range(0, w_pad - tile_size + 1, step))
615
  total_tiles = len(y_positions) * len(x_positions)
616
 
617
+ print(f"Processing {total_tiles} tiles at {w}x{h}...")
618
 
619
  tile_idx = 0
620
  for y in y_positions:
621
  for x in x_positions:
 
622
  tile = img_padded[y:y+tile_size, x:x+tile_size]
623
+ tile_pil = Image.fromarray((tile * 255).astype(np.uint8))
624
 
625
+ result_pil = self.model.translate(tile_pil, seed=42)
626
+ result = np.array(result_pil).astype(np.float32) / 255.0
 
 
 
 
 
627
 
 
 
 
 
 
 
628
  output[y:y+tile_size, x:x+tile_size] += result * blend_weight
629
  weights[y:y+tile_size, x:x+tile_size] += blend_weight
630
 
631
  tile_idx += 1
632
+ print(f" Tile {tile_idx}/{total_tiles}")
 
633
 
 
634
  output = output / (weights + 1e-8)
 
 
635
  output = output[:h, :w]
636
 
637
+ return (output * 255).astype(np.uint8)
638
 
639
+ def enhance(self, image, contrast=1.1, sharpness=1.15, color=1.1):
 
640
  if isinstance(image, np.ndarray):
641
+ image = Image.fromarray(image)
 
 
642
  image = ImageEnhance.Contrast(image).enhance(contrast)
 
643
  image = ImageEnhance.Sharpness(image).enhance(sharpness)
 
644
  image = ImageEnhance.Color(image).enhance(color)
 
645
  return image
646
 
647
 
 
649
  # Gradio Interface
650
  # ============================================================================
651
 
652
+ processor = None
653
 
654
  def load_sar_image(filepath):
655
  """Load SAR image from various formats."""
 
674
  return Image.open(filepath).convert('RGB')
675
 
676
 
677
+ def translate_sar(file, num_steps, overlap, enhance_output):
678
  """Main translation function."""
679
+ global processor
680
 
681
  if file is None:
682
  return None, None, "Please upload a SAR image"
683
 
684
+ if processor is None:
685
+ processor = HighResProcessor()
 
 
686
 
687
+ print("Processing SAR image...")
688
 
 
689
  filepath = file.name if hasattr(file, 'name') else file
690
  image = load_sar_image(filepath)
691
 
692
  w, h = image.size
693
  print(f"Input size: {w}x{h}")
694
 
 
695
  start = time.time()
696
+ result = processor.process(image, overlap=int(overlap), num_steps=int(num_steps))
 
 
 
 
 
697
  elapsed = time.time() - start
698
 
699
+ result_pil = Image.fromarray(result)
 
 
 
700
 
701
+ if enhance_output:
702
+ result_pil = processor.enhance(result_pil)
 
703
 
 
704
  tiff_path = tempfile.mktemp(suffix='.tiff')
705
  result_pil.save(tiff_path, format='TIFF', compression='lzw')
706
 
707
+ print(f"Complete in {elapsed:.1f}s!")
708
 
709
  info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"
710
 
711
  return result_pil, tiff_path, info
712
 
713
 
714
+ # Create interface
715
  with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
716
  gr.Markdown("""
717
  # 🛰️ E3Diff: High-Resolution SAR-to-Optical Translation
 
720
 
721
  - Supports full resolution processing with seamless tiling
722
  - Multiple quality levels (1-8 inference steps)
 
723
  - TIFF output for commercial use
724
  """)
725
 
 
728
  input_file = gr.File(label="SAR Input (TIFF, PNG, JPG supported)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
729
 
730
  with gr.Row():
731
+ num_steps = gr.Slider(1, 8, value=1, step=1, label="Quality Steps (1=fast, 8=best)")
732
+ overlap = gr.Slider(16, 128, value=64, step=16, label="Tile Overlap")
733
 
734
+ enhance = gr.Checkbox(value=True, label="Apply enhancement")
735
 
736
  submit_btn = gr.Button("🚀 Translate to Optical", variant="primary")
737
 
738
  with gr.Column():
739
  output_image = gr.Image(label="Optical Output")
740
+ output_file = gr.File(label="Download TIFF")
741
  info_text = gr.Textbox(label="Processing Info")
742
 
743
  submit_btn.click(
 
748
 
749
  gr.Markdown("""
750
  ---
751
+ **Tips:** The model works best with Sentinel-1 style SAR imagery. Use steps=1 for speed, steps=4-8 for quality.
 
 
 
 
 
752
  """)
753
 
754