Dhenenjay committed on
Commit
23daa0b
·
verified ·
1 Parent(s): 7ec1792

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +684 -0
app.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ E3Diff: High-Resolution SAR-to-Optical Translation
3
+ HuggingFace Spaces Deployment
4
+
5
+ Features:
6
+ - Full resolution processing with seamless tiling
7
+ - Multi-step inference for maximum quality
8
+ - TIFF output support
9
+ - Professional post-processing
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import numpy as np
18
+ from PIL import Image, ImageEnhance
19
+ import gradio as gr
20
+ from pathlib import Path
21
+ import tempfile
22
+ import time
23
+ from tqdm import tqdm
24
+ from huggingface_hub import hf_hub_download
25
+
26
+ # ============================================================================
27
+ # SoftPool Implementation (Pure PyTorch)
28
+ # ============================================================================
29
+
30
def soft_pool2d(x, kernel_size=(2, 2), stride=None, force_inplace=False):
    """Pool a 4-D tensor with exponentially weighted window averages.

    Each output value is ``sum(x * exp(x)) / sum(exp(x))`` taken over one
    pooling window; the per-window maximum is subtracted before
    exponentiation so the exponentials cannot overflow.

    Args:
        x: Tensor whose shape unpacks as (batch, channels, height, width).
        kernel_size: Window size, an int or an (h, w) pair.
        stride: Window stride, an int or an (h, w) pair; falls back to
            ``kernel_size`` when None.
        force_inplace: Accepted for signature compatibility; not used.

    Returns:
        Tensor of shape (batch, channels, out_h, out_w).
    """
    stride = kernel_size if stride is None else stride
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size,) * 2
    if isinstance(stride, int):
        stride = (stride,) * 2

    n, c, in_h, in_w = x.shape
    (kh, kw), (sh, sw) = kernel_size, stride
    out_h = (in_h - kh) // sh + 1
    out_w = (in_w - kw) // sw + 1

    # One row of kh*kw samples per (channel, output position).
    windows = F.unfold(x, kernel_size=kernel_size, stride=stride)
    windows = windows.view(n, c, kh * kw, out_h * out_w)

    peak = windows.max(dim=2, keepdim=True)[0]
    weights = torch.exp(windows - peak)
    numerator = (windows * weights).sum(dim=2)
    denominator = weights.sum(dim=2) + 1e-8
    return (numerator / denominator).view(n, c, out_h, out_w)
50
+
51
+
52
class SoftPool2d(nn.Module):
    """Module wrapper that applies ``soft_pool2d`` with a fixed window.

    Args:
        kernel_size: Window size; anything that is not a tuple is
            duplicated into an (k, k) pair.
        stride: Window stride; defaults to the normalised kernel size.
        force_inplace: Accepted for signature compatibility; not used.
    """

    def __init__(self, kernel_size=(2, 2), stride=None, force_inplace=False):
        super().__init__()
        if not isinstance(kernel_size, tuple):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = kernel_size
        self.stride = self.kernel_size if stride is None else stride

    def forward(self, x):
        """Return the soft-pooled version of ``x``."""
        return soft_pool2d(x, self.kernel_size, self.stride)
60
+
61
+
62
# Monkey-patch SoftPool into the expected location
# Placing an object under sys.modules['SoftPool'] makes any later
# `import SoftPool` / `from SoftPool import ...` resolve to it, so the
# pure-PyTorch functions defined in this file are used in its place.
import sys
class SoftPoolModule:
    # Namespace exposing the two public names of the stand-in module.
    soft_pool2d = staticmethod(soft_pool2d)
    SoftPool2d = SoftPool2d
sys.modules['SoftPool'] = SoftPoolModule()
68
+
69
+ # ============================================================================
70
+ # Model Architecture
71
+ # ============================================================================
72
+
73
+ import math
74
+ from inspect import isfunction
75
+
76
def exists(x):
    """Return True for every value except ``None``."""
    return not (x is None)
78
+
79
def default(val, d):
    """Return ``val`` unless it is None, otherwise the fallback ``d``.

    When ``d`` is a plain Python function it is called and its result is
    returned; any other object is returned as-is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
83
+
84
+
85
class PositionalEncoding(nn.Module):
    """Sinusoidal embedding of a per-sample noise level.

    Args:
        dim: Embedding width; ``dim // 2`` sine and ``dim // 2`` cosine
            features are produced.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, noise_level):
        """Return ``cat(sin, cos)`` features along the last axis."""
        half = self.dim // 2
        positions = torch.arange(
            half, dtype=noise_level.dtype, device=noise_level.device
        )
        fractions = positions / half
        # Frequencies fall geometrically from 1 towards 1e-4.
        frequencies = torch.exp(-math.log(1e4) * fractions.unsqueeze(0))
        angles = noise_level.unsqueeze(1) * frequencies
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)
96
+
97
+
98
class Swish(nn.Module):
    """Activation computing ``x * sigmoid(x)`` element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
101
+
102
+
103
class FeatureWiseAffine(nn.Module):
    """Inject a noise embedding into a feature map, channel by channel.

    Args:
        in_channels: Width of the noise embedding.
        out_channels: Number of feature-map channels being modulated.
        use_affine_level: When true the embedding is projected to a
            (gamma, beta) pair and applied as ``(1 + gamma) * x + beta``;
            otherwise the projection is simply added to ``x``.
    """

    def __init__(self, in_channels, out_channels, use_affine_level=False):
        super().__init__()
        self.use_affine_level = use_affine_level
        projection_width = out_channels * (1 + self.use_affine_level)
        # Kept inside nn.Sequential so parameter names stay "noise_func.0.*".
        self.noise_func = nn.Sequential(nn.Linear(in_channels, projection_width))

    def forward(self, x, noise_embed):
        projected = self.noise_func(noise_embed).view(x.shape[0], -1, 1, 1)
        if not self.use_affine_level:
            return x + projected
        gamma, beta = projected.chunk(2, dim=1)
        return (1 + gamma) * x + beta
117
+
118
+
119
class Upsample(nn.Module):
    """Double the spatial size (nearest neighbour), then apply a 3x3 conv.

    Args:
        dim: Channel count, unchanged by the layer.
    """

    def __init__(self, dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        enlarged = self.up(x)
        return self.conv(enlarged)
127
+
128
+
129
class Downsample(nn.Module):
    """Halve the spatial size with a stride-2, 3x3 convolution.

    Args:
        dim: Channel count, unchanged by the layer.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        reduced = self.conv(x)
        return reduced
136
+
137
+
138
class Block(nn.Module):
    """GroupNorm -> Swish -> (optional Dropout) -> 3x3 convolution.

    Args:
        dim: Input channels.
        dim_out: Output channels.
        groups: Number of GroupNorm groups.
        dropout: Dropout probability; 0 inserts an identity layer instead.
        stride: Stride of the convolution.
    """

    def __init__(self, dim, dim_out, groups=32, dropout=0, stride=1):
        super().__init__()
        regulariser = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        # Layer positions are fixed so parameter names stay "block.<index>.*".
        layers = [
            nn.GroupNorm(groups, dim),
            Swish(),
            regulariser,
            nn.Conv2d(dim, dim_out, 3, stride=stride, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out
150
+
151
+
152
class ResnetBlock(nn.Module):
    """Residual block modulated by a noise embedding and a condition map.

    The data path is ``block1 -> noise injection -> block2``; a 1x1
    projection of the condition feature ``c`` and a skip connection from
    the input are then added.

    Args:
        dim: Input channels.
        dim_out: Output channels.
        noise_level_emb_dim: Width of the noise embedding. When None the
            noise-injection layer is omitted (previously the default
            raised inside ``nn.Linear``).
        dropout: Dropout probability used in the second block.
        use_affine_level: Forwarded to ``FeatureWiseAffine``.
        norm_groups: Number of GroupNorm groups.
    """

    def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32):
        super().__init__()
        if noise_level_emb_dim is not None:
            self.noise_func = FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level)
        else:
            # No embedding configured: skip noise conditioning entirely.
            self.noise_func = None
        self.c_func = nn.Conv2d(dim_out, dim_out, 1)
        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb, c):
        """Apply the block.

        Args:
            x: Input feature map.
            time_emb: Noise embedding; ignored when the block was built
                without ``noise_level_emb_dim``.
            c: Condition feature map; it is passed through a 1x1 conv
                with ``dim_out`` input channels and added to the hidden
                features, so it must match them in shape.
        """
        h = self.block1(x)
        if self.noise_func is not None:
            h = self.noise_func(h, time_emb)
        h = self.block2(h)
        h = self.c_func(c) + h
        return h + self.res_conv(x)
167
+
168
+
169
class SelfAttention(nn.Module):
    """Spatial self-attention over a feature map with a residual add.

    Args:
        in_channel: Channel count of the input (and output).
        n_head: Number of attention heads the channels are split into.
        norm_groups: Number of GroupNorm groups applied before attention.
    """

    def __init__(self, in_channel, n_head=1, norm_groups=32):
        super().__init__()
        self.n_head = n_head
        self.norm = nn.GroupNorm(norm_groups, in_channel)
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.out = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, input, t=None, save_flag=None, file_num=None):
        """Return ``input`` plus the attended features.

        ``t``, ``save_flag`` and ``file_num`` are accepted but unused.
        """
        b, ch, h, w = input.shape
        heads = self.n_head
        per_head = ch // heads

        projected = self.qkv(self.norm(input)).view(b, heads, per_head * 3, h, w)
        q, k, v = projected.chunk(3, dim=2)

        # Similarity of every position (h, w) with every position (y, x);
        # the scale uses the full channel count, not the per-head width.
        scores = torch.einsum("bnchw, bncyx -> bnhwyx", q, k).contiguous() / math.sqrt(ch)
        flat_scores = scores.view(b, heads, h, w, -1)
        probs = torch.softmax(flat_scores, -1).view(b, heads, h, w, h, w)

        mixed = torch.einsum("bnhwyx, bncyx -> bnchw", probs, v).contiguous()
        return self.out(mixed.view(b, ch, h, w)) + input
191
+
192
+
193
class ResnetBlocWithAttn(nn.Module):
    """A ``ResnetBlock`` optionally followed by ``SelfAttention``.

    Args:
        dim: Input channels.
        dim_out: Output channels.
        noise_level_emb_dim: Width of the noise embedding.
        norm_groups: Number of GroupNorm groups.
        dropout: Dropout probability for the residual block.
        with_attn: Whether to append the attention layer.
        size: Accepted but not used by this class.
    """

    def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False, size=256):
        super().__init__()
        self.with_attn = with_attn
        self.res_block = ResnetBlock(
            dim,
            dim_out,
            noise_level_emb_dim,
            norm_groups=norm_groups,
            dropout=dropout,
        )
        if with_attn:
            self.attn = SelfAttention(dim_out, norm_groups=norm_groups)

    def forward(self, x, time_emb, c, t=0, save_flag=False, file_i=0):
        features = self.res_block(x, time_emb, c)
        if not self.with_attn:
            return features
        return self.attn(features, t=t, save_flag=save_flag, file_num=file_i)
206
+
207
+
208
class ResBlock_normal(nn.Module):
    """Plain residual block: two ``Block`` layers plus a skip connection.

    Args:
        dim: Input channels.
        dim_out: Output channels.
        dropout: Dropout probability used in the second block.
        norm_groups: Number of GroupNorm groups.
    """

    def __init__(self, dim, dim_out, dropout=0, norm_groups=32):
        super().__init__()
        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        # A 1x1 conv aligns channel counts on the skip path when needed.
        if dim == dim_out:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)

    def forward(self, x):
        hidden = self.block2(self.block1(x))
        return hidden + self.res_conv(x)
219
+
220
+
221
class CPEN(nn.Module):
    """Five-stage encoder producing a feature pyramid from the condition image.

    Stage widths are 64, 128, 256, 512 and 1024 channels; a 2x2 soft
    pooling with stride 2 is applied between consecutive stages.

    Args:
        inchannel: Channel count of the condition image.
    """

    def __init__(self, inchannel=1):
        super().__init__()

        def stage(c_in, c_mid, c_out):
            # Two residual blocks: c_in -> c_mid -> c_out.
            return nn.Sequential(
                ResBlock_normal(c_in, c_mid, dropout=0, norm_groups=16),
                ResBlock_normal(c_mid, c_out, dropout=0, norm_groups=16),
            )

        self.pool = SoftPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.E1 = nn.Sequential(nn.Conv2d(inchannel, 64, kernel_size=3, padding=1), Swish())
        self.E2 = stage(64, 128, 128)
        self.E3 = stage(128, 256, 256)
        self.E4 = stage(256, 512, 512)
        self.E5 = stage(512, 512, 1024)

    def forward(self, x):
        """Return the tuple ``(x1, x2, x3, x4, x5)`` of stage outputs."""
        pyramid = [self.E1(x)]
        for encoder in (self.E2, self.E3, self.E4, self.E5):
            pyramid.append(encoder(self.pool(pyramid[-1])))
        return tuple(pyramid)
242
+
243
+
244
class UNet(nn.Module):
    """Conditional denoising U-Net with per-level condition features.

    The input to ``forward`` is the channel-wise concatenation of the
    condition image (first ``condition_ch`` channels) and the noisy image.
    The condition image is encoded by ``CPEN`` into five feature maps, one
    per resolution level, which are fed to every residual block.

    Args:
        in_channel: Channels of the tensor entering the first convolution
            (i.e. the input after the condition channels are split off).
        out_channel: Output channels; falls back to ``in_channel`` if None.
        inner_channel: Base channel width.
        norm_groups: Number of GroupNorm groups.
        channel_mults: Width multiplier per resolution level. ``forward``
            consumes exactly the five CPEN feature maps, so it expects five
            levels whose widths match the CPEN stage widths.
        attn_res: Resolutions at which self-attention is enabled. A bare
            int is accepted and treated as a one-element collection.
        res_blocks: Residual blocks per level on the down path (the up
            path uses ``res_blocks + 1``).
        dropout: Dropout probability.
        with_noise_level_emb: Whether to build the noise-level MLP.
        image_size: Nominal input resolution, used only to decide where
            attention is applied.
        condition_ch: Number of leading input channels that hold the
            condition image.
    """

    def __init__(self, in_channel=6, out_channel=3, inner_channel=32, norm_groups=32,
                 channel_mults=(1, 2, 4, 8, 8), attn_res=(8,), res_blocks=3, dropout=0,
                 with_noise_level_emb=True, image_size=128, condition_ch=3):
        super().__init__()

        # The historical default was written "(8)", which is the int 8 and
        # made "now_res in attn_res" raise TypeError; normalise ints here.
        if isinstance(attn_res, int):
            attn_res = (attn_res,)

        if with_noise_level_emb:
            noise_level_channel = inner_channel
            self.noise_level_mlp = nn.Sequential(
                PositionalEncoding(inner_channel),
                nn.Linear(inner_channel, inner_channel * 4),
                Swish(),
                nn.Linear(inner_channel * 4, inner_channel)
            )
        else:
            noise_level_channel = None
            self.noise_level_mlp = None

        self.res_blocks = res_blocks
        num_mults = len(channel_mults)
        self.num_mults = num_mults
        pre_channel = inner_channel
        # Channel count of every skip connection, in push order.
        feat_channels = [pre_channel]
        now_res = image_size

        downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)]
        for ind in range(num_mults):
            is_last = (ind == num_mults - 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(res_blocks):
                downs.append(ResnetBlocWithAttn(pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel,
                                                norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, size=now_res))
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            if not is_last:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
                now_res = now_res // 2
        self.downs = nn.ModuleList(downs)

        self.mid = nn.ModuleList([
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups, dropout=dropout, with_attn=True, size=now_res),
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups, dropout=dropout, with_attn=False, size=now_res)
        ])

        ups = []
        for ind in reversed(range(num_mults)):
            is_last = (ind < 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(res_blocks + 1):
                # Each up block consumes one skip connection via concatenation.
                ups.append(ResnetBlocWithAttn(pre_channel + feat_channels.pop(), channel_mult,
                                              noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                                              dropout=dropout, with_attn=use_attn, size=now_res))
                pre_channel = channel_mult
            if not is_last:
                ups.append(Upsample(pre_channel))
                now_res = now_res * 2
        self.ups = nn.ModuleList(ups)

        self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)
        self.condition = CPEN(inchannel=condition_ch)
        self.condition_ch = condition_ch

    def forward(self, x, time, img_s1=None, class_label=None, return_condition=False, t_ori=0):
        """Run the network.

        Args:
            x: Tensor whose first ``condition_ch`` channels are the
                condition image and whose remaining channels are the noisy
                image.
            time: Noise-level input for the embedding MLP.
            img_s1, class_label, t_ori: Accepted but unused.
            return_condition: When true, also return the five condition
                feature maps.

        Returns:
            The output of the final block, or
            ``(output, [c1, c2, c3, c4, c5])`` when ``return_condition``
            is true.
        """
        condition = x[:, :self.condition_ch, ...].clone()
        x = x[:, self.condition_ch:, ...]

        c1, c2, c3, c4, c5 = self.condition(condition)

        # One condition map per residual block on the way down.
        down_conditions = [feat for feat in (c1, c2, c3, c4, c5)
                           for _ in range(self.res_blocks)]

        t = self.noise_level_mlp(time) if exists(self.noise_level_mlp) else None

        feats = []
        i = 0
        for layer in self.downs:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t, down_conditions[i])
                i += 1
            else:
                x = layer(x)
            feats.append(x)

        for layer in self.mid:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t, c5)
            else:
                x = layer(x)

        # The up path visits the levels coarsest-first with one extra block each.
        up_conditions = [feat for feat in (c5, c4, c3, c2, c1)
                         for _ in range(self.res_blocks + 1)]

        i = 0
        for layer in self.ups:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(torch.cat((x, feats.pop()), dim=1), t, up_conditions[i])
                i += 1
            else:
                x = layer(x)

        if not return_condition:
            return self.final_conv(x)
        return self.final_conv(x), [c1, c2, c3, c4, c5]
359
+
360
+
361
+ # ============================================================================
362
+ # E3Diff High-Resolution Inference
363
+ # ============================================================================
364
+
365
class E3DiffHighRes:
    """Tiled, full-resolution inference wrapper around the E3Diff UNet.

    Args:
        device: Requested torch device; "cpu" is used whenever CUDA is
            not available.
    """

    def __init__(self, device="cuda"):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model = None
        self.image_size = 256  # Side length of the square tiles fed to the network.

    def load_model(self, weights_path=None):
        """Build the UNet and load its weights.

        Args:
            weights_path: Path to a checkpoint. When None the file
                "I700000_E719_gen.pth" is downloaded from the
                "Dhenenjay/E3Diff-SAR2Optical" HuggingFace repository.
        """
        if weights_path is None:
            # Download from HuggingFace
            weights_path = hf_hub_download(
                repo_id="Dhenenjay/E3Diff-SAR2Optical",
                filename="I700000_E719_gen.pth"
            )

        # Build into a local first so a failed load never leaves a randomly
        # initialised network published on self.model.
        net = UNet(
            in_channel=3,
            out_channel=3,
            norm_groups=16,
            inner_channel=64,
            channel_mults=[1, 2, 4, 8, 16],
            attn_res=[],
            res_blocks=1,
            dropout=0,
            image_size=self.image_size,
            condition_ch=3
        ).to(self.device)

        # SECURITY NOTE(review): weights_only=False unpickles arbitrary
        # objects; only load checkpoints from a trusted source.
        state_dict = torch.load(weights_path, map_location=self.device, weights_only=False)

        # Keep only the UNet weights, stripping the leading wrapper prefix.
        prefix = 'denoise_fn.'
        unet_dict = {k[len(prefix):]: v for k, v in state_dict.items()
                     if k.startswith(prefix)}
        if not unet_dict:
            # No prefixed keys: try the checkpoint as a bare UNet state dict
            # rather than silently loading nothing.
            unet_dict = state_dict

        report = net.load_state_dict(unet_dict, strict=False)
        if report.missing_keys or report.unexpected_keys:
            print(f"Warning: checkpoint mismatch - {len(report.missing_keys)} missing, "
                  f"{len(report.unexpected_keys)} unexpected keys")

        net.eval()
        self.model = net
        print(f"Model loaded on {self.device}")

    @torch.no_grad()
    def translate_tile(self, tile_tensor, num_steps=1):
        """Translate one batch of tiles with ``num_steps`` sampling steps.

        Args:
            tile_tensor: Condition tiles with 3 channels, on ``self.device``.
                NOTE(review): callers in this file pass 256x256 tiles scaled
                to [-1, 1].
            num_steps: Number of sampling steps (>= 1).

        Returns:
            Tensor with the spatial size of ``tile_tensor`` and 3 channels,
            clamped to [-1, 1].

        Raises:
            ValueError: If ``num_steps`` is smaller than 1.
        """
        num_steps = int(num_steps)
        if num_steps < 1:
            raise ValueError("num_steps must be at least 1")

        batch_size = tile_tensor.shape[0]
        height, width = tile_tensor.shape[-2:]

        # Initialize noise
        x = torch.randn(batch_size, 3, height, width, device=self.device)

        # DDIM sampling
        total_timesteps = 1000
        ts = torch.linspace(total_timesteps, 0, num_steps + 1).to(self.device).long()

        # Create beta schedule
        betas = torch.linspace(1e-6, 1e-2, total_timesteps, device=self.device)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        sqrt_alphas_cumprod_prev = torch.sqrt(torch.cat([torch.ones(1, device=self.device), alphas_cumprod]))

        for i in range(1, num_steps + 1):
            cur_t = ts[i - 1] - 1
            prev_t = ts[i] - 1

            noise_level = sqrt_alphas_cumprod_prev[cur_t].repeat(batch_size, 1)

            alpha_prod_t = alphas_cumprod[cur_t]
            # prev_t is -1 on the final step, where the product is defined as 1.
            alpha_prod_t_prev = alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0, device=self.device)
            beta_prod_t = 1 - alpha_prod_t

            # Model prediction
            model_input = torch.cat([tile_tensor, x], dim=1)
            model_output = self.model(model_input, noise_level)

            # DDIM update
            pred_original = (x - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
            pred_original = pred_original.clamp(-1, 1)

            sigma_2 = 0.8 * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
            pred_dir = (1 - alpha_prod_t_prev - sigma_2) ** 0.5 * model_output

            if i < num_steps:
                noise = torch.randn_like(x)
                x = alpha_prod_t_prev ** 0.5 * pred_original + pred_dir + sigma_2 ** 0.5 * noise
            else:
                x = pred_original

        return x

    @staticmethod
    def _padded_extent(length, tile_size, overlap):
        """Return the smallest padded length that an integer number of tiles covers."""
        step = tile_size - overlap
        tiles = max(1, -(-(length - overlap) // step))  # ceil division, at least one tile
        return tiles * step + overlap

    def create_blend_weights(self, tile_size, overlap):
        """Create smooth blending weights for seamless tiling.

        The weights ramp linearly across the ``overlap`` outermost pixels of
        every edge and are strictly positive, so every pixel of the stitched
        image receives a non-zero total weight. (A ramp starting at exactly
        0 left the outermost image rows/columns black.) Opposite ramps of
        two neighbouring tiles sum to 1.

        Args:
            tile_size: Side length of the square tile.
            overlap: Width of the ramp in pixels; 0 gives uniform weights.

        Returns:
            float32 array of shape (tile_size, tile_size, 1).

        Raises:
            ValueError: If ``overlap`` exceeds ``tile_size``.
        """
        overlap = int(overlap)
        weight = np.ones((tile_size, tile_size), dtype=np.float32)
        if overlap <= 0:
            return weight[:, :, np.newaxis]
        if overlap > tile_size:
            raise ValueError("overlap must not exceed tile_size")

        # Values 1/(o+1) ... o/(o+1): never 0, never 1.
        ramp = np.arange(1, overlap + 1, dtype=np.float32) / (overlap + 1)

        # Apply ramps to edges
        weight[:overlap, :] *= ramp[:, np.newaxis]        # Top
        weight[-overlap:, :] *= ramp[::-1, np.newaxis]    # Bottom
        weight[:, :overlap] *= ramp[np.newaxis, :]        # Left
        weight[:, -overlap:] *= ramp[np.newaxis, ::-1]    # Right

        return weight[:, :, np.newaxis]

    def translate_full_resolution(self, image, num_steps=1, overlap=64, progress_callback=None):
        """Translate a full-resolution image using seamless tiling.

        The image is reflect-padded on the bottom/right so that overlapping
        tiles cover it completely, each tile is translated independently,
        and the results are blended with ``create_blend_weights``.

        Args:
            image: PIL image (converted to RGB and scaled to [0, 1]) or an
                array of shape (H, W, 3). NOTE(review): arrays are assumed
                to be already scaled to [0, 1].
            num_steps: Sampling steps per tile.
            overlap: Overlap between neighbouring tiles in pixels; must be
                in ``[0, self.image_size)``.
            progress_callback: Optional callable receiving the completed
                fraction after each tile.

        Returns:
            float32 array of shape (H, W, 3) with values in [0, 1].

        Raises:
            ValueError: If ``overlap`` is out of range or the image is empty.
        """
        # Convert to numpy if PIL
        if isinstance(image, Image.Image):
            if image.mode != 'RGB':
                image = image.convert('RGB')
            img_np = np.array(image).astype(np.float32) / 255.0
        else:
            img_np = np.asarray(image, dtype=np.float32)

        tile_size = self.image_size
        overlap = int(overlap)
        if not 0 <= overlap < tile_size:
            raise ValueError(f"overlap must be in [0, {tile_size}), got {overlap}")
        step = tile_size - overlap

        h, w = img_np.shape[:2]
        if h == 0 or w == 0:
            raise ValueError("image must not be empty")

        # Pad image to ensure full coverage (always at least one full tile).
        h_pad = self._padded_extent(h, tile_size, overlap)
        w_pad = self._padded_extent(w, tile_size, overlap)
        img_padded = np.pad(img_np, ((0, h_pad - h), (0, w_pad - w), (0, 0)), mode='reflect')

        # Output arrays
        output = np.zeros((h_pad, w_pad, 3), dtype=np.float32)
        weights = np.zeros((h_pad, w_pad, 1), dtype=np.float32)

        # Blending weights
        blend_weight = self.create_blend_weights(tile_size, overlap)

        # Calculate tile positions
        y_positions = list(range(0, h_pad - tile_size + 1, step))
        x_positions = list(range(0, w_pad - tile_size + 1, step))
        total_tiles = len(y_positions) * len(x_positions)

        print(f"Processing {total_tiles} tiles ({len(x_positions)}x{len(y_positions)})...")

        tile_idx = 0
        for y in y_positions:
            for x in x_positions:
                # Extract tile
                tile = img_padded[y:y+tile_size, x:x+tile_size]

                # Convert to tensor [-1, 1]
                tile_tensor = torch.from_numpy(np.ascontiguousarray(tile)).permute(2, 0, 1).unsqueeze(0)
                tile_tensor = tile_tensor * 2.0 - 1.0
                tile_tensor = tile_tensor.to(self.device)

                # Translate
                result_tensor = self.translate_tile(tile_tensor, num_steps)

                # Convert back to numpy [0, 1]
                result = result_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
                result = (result + 1.0) / 2.0
                result = np.clip(result, 0, 1)

                # Add to output with blending
                output[y:y+tile_size, x:x+tile_size] += result * blend_weight
                weights[y:y+tile_size, x:x+tile_size] += blend_weight

                tile_idx += 1
                if progress_callback:
                    progress_callback(tile_idx / total_tiles)

        # Normalize by weights (strictly positive; the floor only guards rounding).
        output = output / np.maximum(weights, 1e-8)

        # Crop to original size
        return output[:h, :w]

    def enhance_output(self, image, contrast=1.1, sharpness=1.15, color=1.1):
        """Apply contrast, sharpness and colour-saturation enhancement.

        Args:
            image: PIL image, or an array scaled to [0, 1] (values outside
                that range are clipped before conversion to 8-bit).
            contrast, sharpness, color: Enhancement factors passed to the
                corresponding ``ImageEnhance`` classes.

        Returns:
            The enhanced PIL image.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray((np.clip(image, 0, 1) * 255).astype(np.uint8))

        # Contrast
        image = ImageEnhance.Contrast(image).enhance(contrast)
        # Sharpness
        image = ImageEnhance.Sharpness(image).enhance(sharpness)
        # Color saturation
        image = ImageEnhance.Color(image).enhance(color)

        return image
553
+
554
+
555
+ # ============================================================================
556
+ # Gradio Interface
557
+ # ============================================================================
558
+
559
# Lazily created E3DiffHighRes instance shared by all requests.
model = None


def load_sar_image(filepath):
    """Load a SAR image from disk and return it as an RGB PIL image.

    The first band is read with ``rasterio`` when that package is
    available. Floating-point and uint16 data are contrast-stretched
    between their 2nd and 98th percentiles and quantised to 8 bits. If
    rasterio is missing or cannot handle the file, the file is opened with
    PIL instead.

    Args:
        filepath: Path of the image file.

    Returns:
        PIL image in 'RGB' mode.
    """
    try:
        import rasterio
    except ImportError:
        rasterio = None  # Optional dependency: fall back to PIL silently.

    if rasterio is not None:
        try:
            with rasterio.open(filepath) as src:
                data = src.read(1)

            if np.issubdtype(data.dtype, np.floating):
                finite = np.isfinite(data)
                if finite.any():
                    p2, p98 = np.percentile(data[finite], [2, 98])
                    # NaN/Inf would make the uint8 cast undefined; map them
                    # to the low end of the stretch.
                    data = np.where(finite, data, p2)
                    data = np.clip(data, p2, p98)
                    data = ((data - p2) / (p98 - p2 + 1e-8) * 255).astype(np.uint8)
                else:
                    data = np.zeros(data.shape, dtype=np.uint8)
            elif data.dtype == np.uint16:
                p2, p98 = np.percentile(data, [2, 98])
                data = np.clip(data, p2, p98)
                data = ((data - p2) / (p98 - p2 + 1e-8) * 255).astype(np.uint8)

            return Image.fromarray(data).convert('RGB')
        except Exception as exc:
            # Best-effort path: report the reason, then let PIL try.
            print(f"rasterio could not read {filepath} ({exc}); falling back to PIL")

    return Image.open(filepath).convert('RGB')
582
+
583
+
584
def translate_sar(image, num_steps, overlap, enhance, progress=gr.Progress()):
    """Gradio callback: translate a SAR image and return the result.

    Args:
        image: PIL image from the UI, or a file path that is loaded with
            ``load_sar_image``.
        num_steps: Sampling steps per tile (coerced to int).
        overlap: Tile overlap in pixels (coerced to int).
        enhance: Whether to apply ``enhance_output`` to the result.
        progress: Gradio progress tracker.

    Returns:
        Tuple ``(result_image, tiff_path, info_text)``.

    Raises:
        gr.Error: If no image was provided.
    """
    global model

    if image is None:
        raise gr.Error("Please upload a SAR image first.")

    if model is None:
        progress(0, desc="Loading model...")
        # Publish the instance only after the weights loaded successfully;
        # otherwise a failed load would leave a permanently broken global.
        loaded = E3DiffHighRes()
        loaded.load_model()
        model = loaded

    progress(0.1, desc="Processing image...")

    # Handle file upload
    if isinstance(image, str):
        image = load_sar_image(image)

    w, h = image.size
    print(f"Input size: {w}x{h}")

    # Progress callback
    def update_progress(p):
        progress(0.1 + 0.8 * p, desc=f"Translating... {int(p*100)}%")

    # Translate (slider values may arrive as floats)
    start = time.time()
    result = model.translate_full_resolution(
        image,
        num_steps=int(num_steps),
        overlap=int(overlap),
        progress_callback=update_progress
    )
    elapsed = time.time() - start

    progress(0.9, desc="Post-processing...")

    # Convert to PIL
    result_pil = Image.fromarray((result * 255).astype(np.uint8))

    # Enhance if requested
    if enhance:
        result_pil = model.enhance_output(result_pil)

    # Save as TIFF. NamedTemporaryFile creates the file atomically, unlike
    # the deprecated, race-prone tempfile.mktemp; delete=False keeps it on
    # disk so Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix='.tiff', delete=False) as tmp:
        tiff_path = tmp.name
    # Pillow names the LZW TIFF codec "tiff_lzw".
    result_pil.save(tiff_path, format='TIFF', compression='tiff_lzw')

    progress(1.0, desc="Complete!")

    info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"

    return result_pil, tiff_path, info
634
+
635
+
636
# Create Gradio interface
# Layout: inputs and controls in the left column, results in the right
# column; the button is wired to translate_sar below.
with gr.Blocks(title="E3Diff: SAR-to-Optical Translation", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🛰️ E3Diff: High-Resolution SAR-to-Optical Translation

**CVPR PBVS2025 Challenge Winner** | Upload any SAR image and get a photorealistic optical translation.

- Supports full resolution processing with seamless tiling
- Multiple quality levels (1-8 inference steps)
- Professional post-processing
- TIFF output for commercial use
""")

    with gr.Row():
        with gr.Column():
            # type="pil" means translate_sar receives a PIL image.
            input_image = gr.Image(label="SAR Input", type="pil")

            with gr.Row():
                num_steps = gr.Slider(1, 8, value=1, step=1, label="Quality Steps (1=fast, 4-8=high quality)")
                overlap = gr.Slider(16, 128, value=64, step=16, label="Tile Overlap (higher=smoother)")

            enhance = gr.Checkbox(value=True, label="Apply post-processing enhancement")

            submit_btn = gr.Button("🚀 Translate to Optical", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Optical Output")
            output_file = gr.File(label="Download TIFF (full resolution)")
            info_text = gr.Textbox(label="Processing Info")

    # The three outputs match the (image, tiff path, info string) tuple
    # returned by translate_sar.
    submit_btn.click(
        fn=translate_sar,
        inputs=[input_image, num_steps, overlap, enhance],
        outputs=[output_image, output_file, info_text]
    )

    gr.Markdown("""
---
**Tips for best results:**
- For aerial/satellite SAR: Use steps=1-2 for speed, steps=4-8 for quality
- For noisy SAR: Apply speckle filtering first (Lee or PPB filter)
- The model works best with Sentinel-1 style imagery

**Citation:** Qin et al., "Efficient End-to-End Diffusion Model for One-step SAR-to-Optical Translation", IEEE GRSL 2024
""")


# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()