Dhenenjay committed on
Commit
017441b
·
verified ·
1 Parent(s): d52c538

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +83 -46
app.py CHANGED
@@ -1,7 +1,11 @@
1
- """E3Diff: SAR-to-Optical Translation - HuggingFace Space."""
 
 
 
2
 
3
  import os
4
  import torch
 
5
  import numpy as np
6
  from PIL import Image, ImageEnhance
7
  import gradio as gr
@@ -9,7 +13,7 @@ import tempfile
9
  import time
10
  from huggingface_hub import hf_hub_download
11
 
12
- # Import model components
13
  from unet import UNet
14
  from diffusion import GaussianDiffusion
15
 
@@ -23,36 +27,47 @@ except ImportError:
23
 
24
 
25
  class E3DiffInference:
26
- """E3Diff Inference Pipeline - matches local implementation exactly."""
 
 
27
 
28
- def __init__(self, weights_path=None, device="cuda", num_inference_steps=1):
29
  self.device = torch.device(device if torch.cuda.is_available() else "cpu")
30
- self.image_size = 256
31
  self.num_inference_steps = num_inference_steps
32
 
33
  print(f"[E3Diff] Initializing on device: {self.device}")
 
34
  print(f"[E3Diff] Inference steps: {num_inference_steps}")
35
 
 
36
  self.model = self._build_model()
 
 
37
  self._load_weights(weights_path)
 
 
38
  self.model.eval()
39
- print("[E3Diff] Model ready!")
 
40
 
41
  def _build_model(self):
42
- """Build model - exact same config as local inference.py"""
 
43
  unet = UNet(
44
- in_channel=3,
45
- out_channel=3,
46
  norm_groups=16,
47
  inner_channel=64,
48
- channel_mults=[1, 2, 4, 8, 16],
49
- attn_res=[],
50
  res_blocks=1,
51
  dropout=0,
52
  image_size=self.image_size,
53
- condition_ch=3
54
  )
55
 
 
56
  schedule_opt = {
57
  'schedule': 'linear',
58
  'n_timestep': self.num_inference_steps,
@@ -88,7 +103,7 @@ class E3DiffInference:
88
  return model.to(self.device)
89
 
90
  def _load_weights(self, weights_path):
91
- """Load weights - same as local inference.py"""
92
  if weights_path is None:
93
  weights_path = hf_hub_download(
94
  repo_id="Dhenenjay/E3Diff-SAR2Optical",
@@ -98,38 +113,48 @@ class E3DiffInference:
98
  print(f"[E3Diff] Loading weights from: {weights_path}")
99
  state_dict = torch.load(weights_path, map_location=self.device, weights_only=False)
100
  self.model.load_state_dict(state_dict, strict=False)
101
- print("[E3Diff] Weights loaded!")
102
 
103
  def preprocess(self, image):
104
- """Preprocess input image."""
 
105
  if image.mode != 'RGB':
106
  image = image.convert('RGB')
 
 
107
  if image.size != (self.image_size, self.image_size):
108
  image = image.resize((self.image_size, self.image_size), Image.LANCZOS)
109
 
 
110
  img_np = np.array(image).astype(np.float32) / 255.0
111
- img_tensor = torch.from_numpy(img_np).permute(2, 0, 1)
112
- img_tensor = img_tensor * 2.0 - 1.0
 
113
  return img_tensor.unsqueeze(0).to(self.device)
114
 
115
  def postprocess(self, tensor):
116
- """Postprocess output tensor."""
 
117
  tensor = tensor.squeeze(0).cpu()
118
  tensor = torch.clamp(tensor, -1, 1)
119
- tensor = (tensor + 1.0) / 2.0
 
 
120
  img_np = (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
121
  return Image.fromarray(img_np)
122
 
123
  @torch.no_grad()
124
  def translate(self, sar_image, seed=42):
125
- """Translate SAR to optical - same as local inference.py"""
 
126
  if seed is not None:
127
  torch.manual_seed(seed)
128
  np.random.seed(seed)
129
 
130
- sar_tensor = self.preprocess(sar_image)
 
131
 
132
- # Set noise schedule
133
  self.model.set_new_noise_schedule(
134
  {
135
  'schedule': 'linear',
@@ -144,25 +169,30 @@ class E3DiffInference:
144
  )
145
 
146
  # Run inference
147
- output, _ = self.model.super_resolution(sar_tensor, continous=False, seed=seed, img_s1=sar_tensor)
 
 
 
 
 
 
148
  return self.postprocess(output)
149
 
150
 
151
  class HighResProcessor:
152
- """High resolution tiled processing."""
153
 
154
  def __init__(self, device="cuda"):
155
  self.device = device
156
  self.model = None
157
  self.tile_size = 256
158
- self.num_steps = None
159
 
160
- def load_model(self, num_steps=1):
161
- print(f"Loading E3Diff model with {num_steps} steps...")
162
- self.model = E3DiffInference(device=self.device, num_inference_steps=num_steps)
163
- self.num_steps = num_steps
164
 
165
  def create_blend_weights(self, tile_size, overlap):
 
166
  ramp = np.linspace(0, 1, overlap)
167
  weight = np.ones((tile_size, tile_size))
168
  weight[:overlap, :] *= ramp[:, np.newaxis]
@@ -171,9 +201,10 @@ class HighResProcessor:
171
  weight[:, -overlap:] *= ramp[np.newaxis, ::-1]
172
  return weight[:, :, np.newaxis]
173
 
174
- def process(self, image, overlap=64, num_steps=1):
175
- if self.model is None or self.num_steps != num_steps:
176
- self.load_model(num_steps)
 
177
 
178
  if isinstance(image, Image.Image):
179
  if image.mode != 'RGB':
@@ -186,49 +217,59 @@ class HighResProcessor:
186
  tile_size = self.tile_size
187
  step = tile_size - overlap
188
 
 
189
  pad_h = (step - (h - overlap) % step) % step
190
  pad_w = (step - (w - overlap) % step) % step
191
  img_padded = np.pad(img_np, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
192
 
193
  h_pad, w_pad = img_padded.shape[:2]
194
 
 
195
  output = np.zeros((h_pad, w_pad, 3), dtype=np.float32)
196
  weights = np.zeros((h_pad, w_pad, 1), dtype=np.float32)
197
  blend_weight = self.create_blend_weights(tile_size, overlap)
198
 
 
199
  y_positions = list(range(0, h_pad - tile_size + 1, step))
200
  x_positions = list(range(0, w_pad - tile_size + 1, step))
201
  total_tiles = len(y_positions) * len(x_positions)
202
 
203
- print(f"Processing {total_tiles} tiles at {w}x{h}...")
204
 
205
  tile_idx = 0
206
  for y in y_positions:
207
  for x in x_positions:
 
208
  tile = img_padded[y:y+tile_size, x:x+tile_size]
209
  tile_pil = Image.fromarray((tile * 255).astype(np.uint8))
210
 
 
211
  result_pil = self.model.translate(tile_pil, seed=42)
212
  result = np.array(result_pil).astype(np.float32) / 255.0
213
 
 
214
  output[y:y+tile_size, x:x+tile_size] += result * blend_weight
215
  weights[y:y+tile_size, x:x+tile_size] += blend_weight
216
 
217
  tile_idx += 1
218
- if tile_idx % 4 == 0 or tile_idx == total_tiles:
219
  print(f" Tile {tile_idx}/{total_tiles}")
220
 
 
221
  output = output / (weights + 1e-8)
222
  output = output[:h, :w]
223
 
224
  return (output * 255).astype(np.uint8)
225
 
226
- def enhance(self, image, contrast=1.1, sharpness=1.15, color=1.1):
 
227
  if isinstance(image, np.ndarray):
228
  image = Image.fromarray(image)
 
229
  image = ImageEnhance.Contrast(image).enhance(contrast)
230
  image = ImageEnhance.Sharpness(image).enhance(sharpness)
231
  image = ImageEnhance.Color(image).enhance(color)
 
232
  return image
233
 
234
 
@@ -259,7 +300,7 @@ def load_sar_image(filepath):
259
  return Image.open(filepath).convert('RGB')
260
 
261
 
262
- def _translate_sar_impl(file, num_steps, overlap, enhance_output):
263
  """Main translation function."""
264
  global processor
265
 
@@ -278,7 +319,7 @@ def _translate_sar_impl(file, num_steps, overlap, enhance_output):
278
  print(f"Input size: {w}x{h}")
279
 
280
  start = time.time()
281
- result = processor.process(image, overlap=int(overlap), num_steps=int(num_steps))
282
  elapsed = time.time() - start
283
 
284
  result_pil = Image.fromarray(result)
@@ -310,19 +351,15 @@ with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
310
 
311
  **CVPR PBVS2025 Challenge Winner** | Upload any SAR image and get a photorealistic optical translation.
312
 
313
- - Supports full resolution processing with seamless tiling
314
- - Multiple quality levels (1-8 inference steps)
315
  - TIFF output for commercial use
316
  """)
317
 
318
  with gr.Row():
319
  with gr.Column():
320
  input_file = gr.File(label="SAR Input (TIFF, PNG, JPG)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
321
-
322
- with gr.Row():
323
- num_steps = gr.Slider(1, 8, value=1, step=1, label="Quality Steps (1=fast, 8=best)")
324
- overlap = gr.Slider(16, 128, value=64, step=16, label="Tile Overlap")
325
-
326
  enhance = gr.Checkbox(value=True, label="Apply enhancement")
327
  submit_btn = gr.Button("🚀 Translate to Optical", variant="primary")
328
 
@@ -333,13 +370,13 @@ with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
333
 
334
  submit_btn.click(
335
  fn=translate_sar,
336
- inputs=[input_file, num_steps, overlap, enhance],
337
  outputs=[output_image, output_file, info_text]
338
  )
339
 
340
  gr.Markdown("""
341
  ---
342
- **Tips:** Use steps=1 for speed, steps=4-8 for quality. Works best with Sentinel-1 style SAR.
343
  """)
344
 
345
 
 
1
+ """
2
+ E3Diff: SAR-to-Optical Translation - HuggingFace Space
3
+ Exact copy of working local implementation
4
+ """
5
 
6
  import os
7
  import torch
8
+ import torch.nn as nn
9
  import numpy as np
10
  from PIL import Image, ImageEnhance
11
  import gradio as gr
 
13
  import time
14
  from huggingface_hub import hf_hub_download
15
 
16
+ # Import model components (exact same as local)
17
  from unet import UNet
18
  from diffusion import GaussianDiffusion
19
 
 
27
 
28
 
29
  class E3DiffInference:
30
+ """
31
+ E3Diff Inference Pipeline - EXACT copy from local inference.py
32
+ """
33
 
34
+ def __init__(self, weights_path=None, device="cuda", image_size=256, num_inference_steps=1):
35
  self.device = torch.device(device if torch.cuda.is_available() else "cpu")
36
+ self.image_size = image_size
37
  self.num_inference_steps = num_inference_steps
38
 
39
  print(f"[E3Diff] Initializing on device: {self.device}")
40
+ print(f"[E3Diff] Image size: {image_size}x{image_size}")
41
  print(f"[E3Diff] Inference steps: {num_inference_steps}")
42
 
43
+ # Build model
44
  self.model = self._build_model()
45
+
46
+ # Load weights
47
  self._load_weights(weights_path)
48
+
49
+ # Set to eval mode
50
  self.model.eval()
51
+
52
+ print("[E3Diff] Model ready for inference!")
53
 
54
  def _build_model(self):
55
+ """Build the E3Diff model architecture - exact same config."""
56
+ # UNet configuration from SEN12_256_s2_test.json
57
  unet = UNet(
58
+ in_channel=3, # Noisy image channels
59
+ out_channel=3, # Output optical image
60
  norm_groups=16,
61
  inner_channel=64,
62
+ channel_mults=[1, 2, 4, 8, 16], # Encoder/decoder channels
63
+ attn_res=[], # No attention at specific resolutions
64
  res_blocks=1,
65
  dropout=0,
66
  image_size=self.image_size,
67
+ condition_ch=3 # SAR condition channels
68
  )
69
 
70
+ # Diffusion wrapper
71
  schedule_opt = {
72
  'schedule': 'linear',
73
  'n_timestep': self.num_inference_steps,
 
103
  return model.to(self.device)
104
 
105
  def _load_weights(self, weights_path):
106
+ """Load pre-trained weights."""
107
  if weights_path is None:
108
  weights_path = hf_hub_download(
109
  repo_id="Dhenenjay/E3Diff-SAR2Optical",
 
113
  print(f"[E3Diff] Loading weights from: {weights_path}")
114
  state_dict = torch.load(weights_path, map_location=self.device, weights_only=False)
115
  self.model.load_state_dict(state_dict, strict=False)
116
+ print(f"[E3Diff] Weights loaded successfully!")
117
 
118
  def preprocess(self, image):
119
+ """Preprocess input SAR image."""
120
+ # Convert to RGB if grayscale
121
  if image.mode != 'RGB':
122
  image = image.convert('RGB')
123
+
124
+ # Resize to model input size
125
  if image.size != (self.image_size, self.image_size):
126
  image = image.resize((self.image_size, self.image_size), Image.LANCZOS)
127
 
128
+ # Convert to tensor and normalize to [-1, 1]
129
  img_np = np.array(image).astype(np.float32) / 255.0
130
+ img_tensor = torch.from_numpy(img_np).permute(2, 0, 1) # HWC -> CHW
131
+ img_tensor = img_tensor * 2.0 - 1.0 # [0,1] -> [-1,1]
132
+
133
  return img_tensor.unsqueeze(0).to(self.device)
134
 
135
  def postprocess(self, tensor):
136
+ """Postprocess output tensor to PIL Image."""
137
+ # Clamp and denormalize
138
  tensor = tensor.squeeze(0).cpu()
139
  tensor = torch.clamp(tensor, -1, 1)
140
+ tensor = (tensor + 1.0) / 2.0 # [-1,1] -> [0,1]
141
+
142
+ # Convert to numpy and PIL
143
  img_np = (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
144
  return Image.fromarray(img_np)
145
 
146
  @torch.no_grad()
147
  def translate(self, sar_image, seed=42):
148
+ """Translate SAR image to optical image."""
149
+ # Set seed for reproducibility
150
  if seed is not None:
151
  torch.manual_seed(seed)
152
  np.random.seed(seed)
153
 
154
+ # Preprocess
155
+ sar_tensor = self.preprocess(sar_image) # [1, 3, H, W]
156
 
157
+ # Set noise schedule for inference
158
  self.model.set_new_noise_schedule(
159
  {
160
  'schedule': 'linear',
 
169
  )
170
 
171
  # Run inference
172
+ output, output_onestep = self.model.super_resolution(
173
+ sar_tensor,
174
+ continous=False,
175
+ seed=seed if seed is not None else 1,
176
+ img_s1=sar_tensor
177
+ )
178
+
179
  return self.postprocess(output)
180
 
181
 
182
  class HighResProcessor:
183
+ """High resolution tiled processing - exact copy from process_highres.py"""
184
 
185
  def __init__(self, device="cuda"):
186
  self.device = device
187
  self.model = None
188
  self.tile_size = 256
 
189
 
190
+ def load_model(self):
191
+ print("Loading E3Diff model...")
192
+ self.model = E3DiffInference(device=self.device, num_inference_steps=1)
 
193
 
194
  def create_blend_weights(self, tile_size, overlap):
195
+ """Create smooth blending weights for seamless output."""
196
  ramp = np.linspace(0, 1, overlap)
197
  weight = np.ones((tile_size, tile_size))
198
  weight[:overlap, :] *= ramp[:, np.newaxis]
 
201
  weight[:, -overlap:] *= ramp[np.newaxis, ::-1]
202
  return weight[:, :, np.newaxis]
203
 
204
+ def process(self, image, overlap=64):
205
+ """Process image at full resolution with seamless tiling."""
206
+ if self.model is None:
207
+ self.load_model()
208
 
209
  if isinstance(image, Image.Image):
210
  if image.mode != 'RGB':
 
217
  tile_size = self.tile_size
218
  step = tile_size - overlap
219
 
220
+ # Pad image
221
  pad_h = (step - (h - overlap) % step) % step
222
  pad_w = (step - (w - overlap) % step) % step
223
  img_padded = np.pad(img_np, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
224
 
225
  h_pad, w_pad = img_padded.shape[:2]
226
 
227
+ # Output arrays
228
  output = np.zeros((h_pad, w_pad, 3), dtype=np.float32)
229
  weights = np.zeros((h_pad, w_pad, 1), dtype=np.float32)
230
  blend_weight = self.create_blend_weights(tile_size, overlap)
231
 
232
+ # Calculate positions
233
  y_positions = list(range(0, h_pad - tile_size + 1, step))
234
  x_positions = list(range(0, w_pad - tile_size + 1, step))
235
  total_tiles = len(y_positions) * len(x_positions)
236
 
237
+ print(f"Processing {total_tiles} tiles ({len(x_positions)}x{len(y_positions)}) at {w}x{h}...")
238
 
239
  tile_idx = 0
240
  for y in y_positions:
241
  for x in x_positions:
242
+ # Extract tile
243
  tile = img_padded[y:y+tile_size, x:x+tile_size]
244
  tile_pil = Image.fromarray((tile * 255).astype(np.uint8))
245
 
246
+ # Translate
247
  result_pil = self.model.translate(tile_pil, seed=42)
248
  result = np.array(result_pil).astype(np.float32) / 255.0
249
 
250
+ # Blend
251
  output[y:y+tile_size, x:x+tile_size] += result * blend_weight
252
  weights[y:y+tile_size, x:x+tile_size] += blend_weight
253
 
254
  tile_idx += 1
255
+ if tile_idx % 10 == 0 or tile_idx == total_tiles:
256
  print(f" Tile {tile_idx}/{total_tiles}")
257
 
258
+ # Normalize
259
  output = output / (weights + 1e-8)
260
  output = output[:h, :w]
261
 
262
  return (output * 255).astype(np.uint8)
263
 
264
+ def enhance(self, image, contrast=1.1, sharpness=1.2, color=1.1):
265
+ """Professional post-processing."""
266
  if isinstance(image, np.ndarray):
267
  image = Image.fromarray(image)
268
+
269
  image = ImageEnhance.Contrast(image).enhance(contrast)
270
  image = ImageEnhance.Sharpness(image).enhance(sharpness)
271
  image = ImageEnhance.Color(image).enhance(color)
272
+
273
  return image
274
 
275
 
 
300
  return Image.open(filepath).convert('RGB')
301
 
302
 
303
+ def _translate_sar_impl(file, overlap, enhance_output):
304
  """Main translation function."""
305
  global processor
306
 
 
319
  print(f"Input size: {w}x{h}")
320
 
321
  start = time.time()
322
+ result = processor.process(image, overlap=int(overlap))
323
  elapsed = time.time() - start
324
 
325
  result_pil = Image.fromarray(result)
 
351
 
352
  **CVPR PBVS2025 Challenge Winner** | Upload any SAR image and get a photorealistic optical translation.
353
 
354
+ - Full resolution processing with seamless tiling
355
+ - One-step diffusion (optimized for speed & quality)
356
  - TIFF output for commercial use
357
  """)
358
 
359
  with gr.Row():
360
  with gr.Column():
361
  input_file = gr.File(label="SAR Input (TIFF, PNG, JPG)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
362
+ overlap = gr.Slider(16, 128, value=64, step=16, label="Tile Overlap (higher=smoother)")
 
 
 
 
363
  enhance = gr.Checkbox(value=True, label="Apply enhancement")
364
  submit_btn = gr.Button("🚀 Translate to Optical", variant="primary")
365
 
 
370
 
371
  submit_btn.click(
372
  fn=translate_sar,
373
+ inputs=[input_file, overlap, enhance],
374
  outputs=[output_image, output_file, info_text]
375
  )
376
 
377
  gr.Markdown("""
378
  ---
379
+ **Note:** E3Diff is a one-step diffusion model. Multiple steps degrade quality.
380
  """)
381
 
382