Dhenenjay committed on
Commit
aef5404
·
verified ·
1 Parent(s): 9680613

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +273 -267
app.py CHANGED
@@ -1,280 +1,50 @@
1
  """
2
  E3Diff: SAR-to-Optical Translation - HuggingFace Space
3
- Exact copy of working local implementation
4
  """
5
 
6
  import os
7
- import torch
8
- import torch.nn as nn
9
  import numpy as np
10
  from PIL import Image, ImageEnhance
11
  import gradio as gr
12
  import tempfile
13
  import time
14
- from huggingface_hub import hf_hub_download
15
 
16
- # Import model components (exact same as local)
17
- from unet import UNet
18
- from diffusion import GaussianDiffusion
19
 
20
  # ZeroGPU support
21
  try:
22
  import spaces
23
  GPU_AVAILABLE = True
 
24
  except ImportError:
25
  GPU_AVAILABLE = False
26
  spaces = None
 
27
 
28
 
29
- class E3DiffInference:
30
- """
31
- E3Diff Inference Pipeline - EXACT copy from local inference.py
32
- """
33
-
34
- def __init__(self, weights_path=None, device="cuda", image_size=256, num_inference_steps=1):
35
- self.device = torch.device(device if torch.cuda.is_available() else "cpu")
36
- self.image_size = image_size
37
- self.num_inference_steps = num_inference_steps
38
-
39
- print(f"[E3Diff] Initializing on device: {self.device}")
40
- print(f"[E3Diff] Image size: {image_size}x{image_size}")
41
- print(f"[E3Diff] Inference steps: {num_inference_steps}")
42
-
43
- # Build model
44
- self.model = self._build_model()
45
-
46
- # Load weights
47
- self._load_weights(weights_path)
48
-
49
- # Set to eval mode
50
- self.model.eval()
51
-
52
- print("[E3Diff] Model ready for inference!")
53
-
54
- def _build_model(self):
55
- """Build the E3Diff model architecture - exact same config."""
56
- # UNet configuration from SEN12_256_s2_test.json
57
- unet = UNet(
58
- in_channel=3, # Noisy image channels
59
- out_channel=3, # Output optical image
60
- norm_groups=16,
61
- inner_channel=64,
62
- channel_mults=[1, 2, 4, 8, 16], # Encoder/decoder channels
63
- attn_res=[], # No attention at specific resolutions
64
- res_blocks=1,
65
- dropout=0,
66
- image_size=self.image_size,
67
- condition_ch=3 # SAR condition channels
68
- )
69
-
70
- # Diffusion wrapper
71
- schedule_opt = {
72
- 'schedule': 'linear',
73
- 'n_timestep': self.num_inference_steps,
74
- 'linear_start': 1e-6,
75
- 'linear_end': 1e-2,
76
- 'ddim': 1,
77
- 'lq_noiselevel': 0
78
- }
79
-
80
- opt = {
81
- 'stage': 2,
82
- 'ddim_steps': self.num_inference_steps,
83
- 'model': {
84
- 'beta_schedule': {
85
- 'train': {'n_timestep': 1000},
86
- 'val': schedule_opt
87
- }
88
- }
89
- }
90
-
91
- model = GaussianDiffusion(
92
- denoise_fn=unet,
93
- image_size=self.image_size,
94
- channels=3,
95
- loss_type='l1',
96
- conditional=True,
97
- schedule_opt=schedule_opt,
98
- xT_noise_r=0,
99
- seed=1,
100
- opt=opt
101
- )
102
-
103
- return model.to(self.device)
104
-
105
- def _load_weights(self, weights_path):
106
- """Load pre-trained weights."""
107
- if weights_path is None:
108
- weights_path = hf_hub_download(
109
- repo_id="Dhenenjay/E3Diff-SAR2Optical",
110
- filename="I700000_E719_gen.pth"
111
- )
112
-
113
- print(f"[E3Diff] Loading weights from: {weights_path}")
114
- state_dict = torch.load(weights_path, map_location=self.device, weights_only=False)
115
- self.model.load_state_dict(state_dict, strict=False)
116
- print(f"[E3Diff] Weights loaded successfully!")
117
-
118
- def preprocess(self, image):
119
- """Preprocess input SAR image."""
120
- # Convert to RGB if grayscale
121
- if image.mode != 'RGB':
122
- image = image.convert('RGB')
123
-
124
- # Resize to model input size
125
- if image.size != (self.image_size, self.image_size):
126
- image = image.resize((self.image_size, self.image_size), Image.LANCZOS)
127
-
128
- # Convert to tensor and normalize to [-1, 1]
129
- img_np = np.array(image).astype(np.float32) / 255.0
130
- img_tensor = torch.from_numpy(img_np).permute(2, 0, 1) # HWC -> CHW
131
- img_tensor = img_tensor * 2.0 - 1.0 # [0,1] -> [-1,1]
132
-
133
- return img_tensor.unsqueeze(0).to(self.device)
134
-
135
- def postprocess(self, tensor):
136
- """Postprocess output tensor to PIL Image."""
137
- # Clamp and denormalize
138
- tensor = tensor.squeeze(0).cpu()
139
- tensor = torch.clamp(tensor, -1, 1)
140
- tensor = (tensor + 1.0) / 2.0 # [-1,1] -> [0,1]
141
-
142
- # Convert to numpy and PIL
143
- img_np = (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
144
- return Image.fromarray(img_np)
145
-
146
- @torch.no_grad()
147
- def translate(self, sar_image, seed=42):
148
- """Translate SAR image to optical image."""
149
- # Set seed for reproducibility
150
- if seed is not None:
151
- torch.manual_seed(seed)
152
- np.random.seed(seed)
153
-
154
- # Preprocess
155
- sar_tensor = self.preprocess(sar_image) # [1, 3, H, W]
156
-
157
- # Set noise schedule for inference
158
- self.model.set_new_noise_schedule(
159
- {
160
- 'schedule': 'linear',
161
- 'n_timestep': self.num_inference_steps,
162
- 'linear_start': 1e-6,
163
- 'linear_end': 1e-2,
164
- 'ddim': 1,
165
- 'lq_noiselevel': 0
166
- },
167
- self.device,
168
- num_train_timesteps=1000
169
- )
170
-
171
- # Run inference
172
- output, output_onestep = self.model.super_resolution(
173
- sar_tensor,
174
- continous=False,
175
- seed=seed if seed is not None else 1,
176
- img_s1=sar_tensor
177
- )
178
-
179
- return self.postprocess(output)
180
-
181
-
182
- class HighResProcessor:
183
- """High resolution tiled processing - exact copy from process_highres.py"""
184
-
185
- def __init__(self, device="cuda"):
186
- self.device = device
187
- self.model = None
188
- self.tile_size = 256
189
-
190
- def load_model(self):
191
- print("Loading E3Diff model...")
192
- self.model = E3DiffInference(device=self.device, num_inference_steps=1)
193
-
194
- def create_blend_weights(self, tile_size, overlap):
195
- """Create smooth blending weights for seamless output."""
196
- ramp = np.linspace(0, 1, overlap)
197
- weight = np.ones((tile_size, tile_size))
198
- weight[:overlap, :] *= ramp[:, np.newaxis]
199
- weight[-overlap:, :] *= ramp[::-1, np.newaxis]
200
- weight[:, :overlap] *= ramp[np.newaxis, :]
201
- weight[:, -overlap:] *= ramp[np.newaxis, ::-1]
202
- return weight[:, :, np.newaxis]
203
-
204
- def process(self, image, overlap=64):
205
- """Process image at full resolution with seamless tiling."""
206
- if self.model is None:
207
- self.load_model()
208
-
209
- if isinstance(image, Image.Image):
210
- if image.mode != 'RGB':
211
- image = image.convert('RGB')
212
- img_np = np.array(image).astype(np.float32) / 255.0
213
- else:
214
- img_np = image
215
-
216
- h, w = img_np.shape[:2]
217
- tile_size = self.tile_size
218
- step = tile_size - overlap
219
-
220
- # Pad image
221
- pad_h = (step - (h - overlap) % step) % step
222
- pad_w = (step - (w - overlap) % step) % step
223
- img_padded = np.pad(img_np, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
224
-
225
- h_pad, w_pad = img_padded.shape[:2]
226
-
227
- # Output arrays
228
- output = np.zeros((h_pad, w_pad, 3), dtype=np.float32)
229
- weights = np.zeros((h_pad, w_pad, 1), dtype=np.float32)
230
- blend_weight = self.create_blend_weights(tile_size, overlap)
231
-
232
- # Calculate positions
233
- y_positions = list(range(0, h_pad - tile_size + 1, step))
234
- x_positions = list(range(0, w_pad - tile_size + 1, step))
235
- total_tiles = len(y_positions) * len(x_positions)
236
-
237
- print(f"Processing {total_tiles} tiles ({len(x_positions)}x{len(y_positions)}) at {w}x{h}...")
238
-
239
- tile_idx = 0
240
- for y in y_positions:
241
- for x in x_positions:
242
- # Extract tile
243
- tile = img_padded[y:y+tile_size, x:x+tile_size]
244
- tile_pil = Image.fromarray((tile * 255).astype(np.uint8))
245
-
246
- # Translate
247
- result_pil = self.model.translate(tile_pil, seed=42)
248
- result = np.array(result_pil).astype(np.float32) / 255.0
249
-
250
- # Blend
251
- output[y:y+tile_size, x:x+tile_size] += result * blend_weight
252
- weights[y:y+tile_size, x:x+tile_size] += blend_weight
253
-
254
- tile_idx += 1
255
- if tile_idx % 10 == 0 or tile_idx == total_tiles:
256
- print(f" Tile {tile_idx}/{total_tiles}")
257
-
258
- # Normalize
259
- output = output / (weights + 1e-8)
260
- output = output[:h, :w]
261
-
262
- return (output * 255).astype(np.uint8)
263
-
264
- def enhance(self, image, contrast=1.1, sharpness=1.2, color=1.1):
265
- """Professional post-processing."""
266
- if isinstance(image, np.ndarray):
267
- image = Image.fromarray(image)
268
-
269
- image = ImageEnhance.Contrast(image).enhance(contrast)
270
- image = ImageEnhance.Sharpness(image).enhance(sharpness)
271
- image = ImageEnhance.Color(image).enhance(color)
272
-
273
- return image
274
 
 
 
 
 
 
 
 
 
275
 
276
- # Global processor
277
- processor = None
 
 
 
 
 
 
 
278
 
279
 
280
  def load_sar_image(filepath):
@@ -300,49 +70,284 @@ def load_sar_image(filepath):
300
  return Image.open(filepath).convert('RGB')
301
 
302
 
303
- def _translate_sar_impl(file, overlap, enhance_output):
304
- """Main translation function."""
305
- global processor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  if file is None:
308
  return None, None, "Please upload a SAR image"
309
 
310
- if processor is None:
311
- processor = HighResProcessor()
 
 
 
 
 
312
 
313
- print("Processing SAR image...")
314
 
 
315
  filepath = file.name if hasattr(file, 'name') else file
 
316
  image = load_sar_image(filepath)
317
 
318
  w, h = image.size
319
- print(f"Input size: {w}x{h}")
320
 
321
  start = time.time()
322
- result = processor.process(image, overlap=int(overlap))
323
  elapsed = time.time() - start
324
 
325
  result_pil = Image.fromarray(result)
326
 
327
  if enhance_output:
328
- result_pil = processor.enhance(result_pil)
329
 
330
  tiff_path = tempfile.mktemp(suffix='.tiff')
331
  result_pil.save(tiff_path, format='TIFF', compression='lzw')
332
 
333
- print(f"Complete in {elapsed:.1f}s!")
334
 
335
  info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"
336
 
337
  return result_pil, tiff_path, info
338
 
339
 
340
- # Apply GPU decorator if available
341
  if GPU_AVAILABLE and spaces is not None:
342
- translate_sar = spaces.GPU(duration=300)(_translate_sar_impl)
 
 
343
  else:
344
- translate_sar = _translate_sar_impl
 
345
 
 
346
 
347
  # Create Gradio interface
348
  with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
@@ -379,6 +384,7 @@ with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
379
  **Note:** E3Diff is a one-step diffusion model. Multiple steps degrade quality.
380
  """)
381
 
 
382
 
383
  if __name__ == "__main__":
384
  demo.queue().launch(ssr_mode=False)
 
1
  """
2
  E3Diff: SAR-to-Optical Translation - HuggingFace Space
3
+ Fixed for ZeroGPU with lazy loading
4
  """
5
 
6
  import os
 
 
7
  import numpy as np
8
  from PIL import Image, ImageEnhance
9
  import gradio as gr
10
  import tempfile
11
  import time
 
12
 
13
+ print("[E3Diff] Starting app...")
 
 
14
 
15
  # ZeroGPU support
16
  try:
17
  import spaces
18
  GPU_AVAILABLE = True
19
+ print("[E3Diff] ZeroGPU available")
20
  except ImportError:
21
  GPU_AVAILABLE = False
22
  spaces = None
23
+ print("[E3Diff] Running without ZeroGPU")
24
 
25
 
26
# Heavy modules are imported on first use so the app process starts fast
# (important on ZeroGPU, where startup happens without a GPU attached).
_torch = None
_model_modules = None


def get_torch():
    """Import torch on first call and return the cached module thereafter."""
    global _torch
    if _torch is not None:
        return _torch
    print("[E3Diff] Importing torch...")
    import torch
    _torch = torch
    print(f"[E3Diff] PyTorch {torch.__version__} loaded")
    return _torch


def get_model_modules():
    """Lazily import the project model classes.

    Returns:
        (UNet, GaussianDiffusion) tuple, cached after the first call.
    """
    global _model_modules
    if _model_modules is not None:
        return _model_modules
    print("[E3Diff] Importing model modules...")
    from unet import UNet
    from diffusion import GaussianDiffusion
    _model_modules = (UNet, GaussianDiffusion)
    print("[E3Diff] Model modules loaded")
    return _model_modules
48
 
49
 
50
  def load_sar_image(filepath):
 
70
  return Image.open(filepath).convert('RGB')
71
 
72
 
73
def create_blend_weights(tile_size, overlap):
    """Build per-pixel blending weights for seamless tile stitching.

    Returns a (tile_size, tile_size, 1) float array that ramps linearly
    from 0 to 1 over `overlap` pixels along every edge, so overlapping
    tiles cross-fade instead of showing seams.

    Fix: overlap <= 0 previously raised a NumPy broadcasting error
    (weight[-0:, :] selects the whole array, which cannot absorb a
    zero-length ramp); it now means "no blending" and returns uniform
    weights.
    """
    if overlap <= 0:
        return np.ones((tile_size, tile_size, 1))
    ramp = np.linspace(0, 1, overlap)
    weight = np.ones((tile_size, tile_size))
    weight[:overlap, :] *= ramp[:, np.newaxis]       # top edge fades in
    weight[-overlap:, :] *= ramp[::-1, np.newaxis]   # bottom edge fades out
    weight[:, :overlap] *= ramp[np.newaxis, :]       # left edge fades in
    weight[:, -overlap:] *= ramp[np.newaxis, ::-1]   # right edge fades out
    return weight[:, :, np.newaxis]
82
+
83
+
84
def build_model(device):
    """Build the E3Diff generator on `device` and load its pretrained weights.

    Architecture and schedule values mirror the training configuration
    (256px tiles, one-step DDIM); they must stay in sync with the
    published checkpoint or load_state_dict will silently mismatch.
    """
    torch = get_torch()
    UNet, GaussianDiffusion = get_model_modules()
    from huggingface_hub import hf_hub_download

    print("[E3Diff] Building model architecture...")

    image_size = 256
    num_inference_steps = 1  # E3Diff is a one-step diffusion model

    # UNet configuration
    unet = UNet(
        in_channel=3,                    # noisy image channels
        out_channel=3,                   # predicted optical image channels
        norm_groups=16,
        inner_channel=64,
        channel_mults=[1, 2, 4, 8, 16],  # encoder/decoder channel multipliers
        attn_res=[],                     # no attention blocks at any resolution
        res_blocks=1,
        dropout=0,
        image_size=image_size,
        condition_ch=3                   # SAR conditioning channels
    )

    # Diffusion wrapper
    # Inference-time (val) noise schedule: a single DDIM step.
    schedule_opt = {
        'schedule': 'linear',
        'n_timestep': num_inference_steps,
        'linear_start': 1e-6,
        'linear_end': 1e-2,
        'ddim': 1,
        'lq_noiselevel': 0
    }

    opt = {
        'stage': 2,
        'ddim_steps': num_inference_steps,
        'model': {
            'beta_schedule': {
                'train': {'n_timestep': 1000},  # full schedule used at train time
                'val': schedule_opt
            }
        }
    }

    model = GaussianDiffusion(
        denoise_fn=unet,
        image_size=image_size,
        channels=3,
        loss_type='l1',
        conditional=True,
        schedule_opt=schedule_opt,
        xT_noise_r=0,
        seed=1,
        opt=opt
    )

    model = model.to(device)

    # Load weights
    print("[E3Diff] Downloading weights...")
    weights_path = hf_hub_download(
        repo_id="Dhenenjay/E3Diff-SAR2Optical",
        filename="I700000_E719_gen.pth"
    )

    print(f"[E3Diff] Loading weights from: {weights_path}")
    # weights_only=False unpickles arbitrary objects — acceptable here only
    # because the checkpoint comes from the project's own HF repo.
    state_dict = torch.load(weights_path, map_location=device, weights_only=False)
    # strict=False: tolerate checkpoint keys this inference graph doesn't use.
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    print("[E3Diff] Model ready!")
    return model
158
+
159
+
160
def preprocess(image, device, image_size=256):
    """Convert a PIL SAR image into a normalized [1, 3, H, W] tensor on `device`.

    Forces RGB mode, resizes to `image_size` square if needed, and rescales
    pixel values from [0, 255] into [-1, 1].
    """
    torch = get_torch()

    rgb = image if image.mode == 'RGB' else image.convert('RGB')
    if rgb.size != (image_size, image_size):
        rgb = rgb.resize((image_size, image_size), Image.LANCZOS)

    pixels = np.array(rgb).astype(np.float32) / 255.0    # HWC, [0, 1]
    chw = torch.from_numpy(pixels).permute(2, 0, 1)      # HWC -> CHW
    normalized = chw * 2.0 - 1.0                         # [0, 1] -> [-1, 1]

    return normalized.unsqueeze(0).to(device)
175
+
176
+
177
def postprocess(tensor):
    """Convert a [1, 3, H, W] model output in [-1, 1] into a PIL RGB image."""
    torch = get_torch()

    chw = torch.clamp(tensor.squeeze(0).cpu(), -1, 1)
    unit = (chw + 1.0) / 2.0                 # [-1, 1] -> [0, 1]
    hwc = unit.permute(1, 2, 0).numpy()      # CHW -> HWC

    return Image.fromarray((hwc * 255).astype(np.uint8))
187
+
188
+
189
def translate_tile(model, sar_pil, device, seed=42):
    """Run one-step SAR->optical translation on a single tile.

    When `seed` is not None, torch and numpy RNGs are seeded first so
    repeated calls on the same tile are reproducible. The model's
    inference noise schedule is (re)applied before each call.
    """
    torch = get_torch()

    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    condition = preprocess(sar_pil, device)

    # Single-step DDIM schedule for inference.
    inference_schedule = {
        'schedule': 'linear',
        'n_timestep': 1,
        'linear_start': 1e-6,
        'linear_end': 1e-2,
        'ddim': 1,
        'lq_noiselevel': 0
    }
    model.set_new_noise_schedule(inference_schedule, device, num_train_timesteps=1000)

    with torch.no_grad():
        translated, _ = model.super_resolution(
            condition,
            continous=False,
            seed=1 if seed is None else seed,
            img_s1=condition
        )

    return postprocess(translated)
221
+
222
+
223
def enhance_image(image, contrast=1.1, sharpness=1.2, color=1.1):
    """Apply mild contrast, sharpness, and saturation boosts.

    Accepts either an ndarray or a PIL image; always returns a PIL image.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    for enhancer, amount in (
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Sharpness, sharpness),
        (ImageEnhance.Color, color),
    ):
        image = enhancer(image).enhance(amount)

    return image
233
+
234
+
235
def process_image(image, model, device, overlap=64):
    """Translate an arbitrarily large SAR image via overlapping tiles.

    Reflect-pads the image so a grid of 256x256 tiles covers it exactly,
    translates each tile with translate_tile(), and cross-fades the
    results using create_blend_weights(). Returns a uint8 HxWx3 array
    cropped back to the input size.

    NOTE(review): inputs smaller than 256px on a side yield zero tile
    positions and would come back black — confirm callers always supply
    images >= 256x256.
    """
    if isinstance(image, Image.Image):
        if image.mode != 'RGB':
            image = image.convert('RGB')
        pixels = np.array(image).astype(np.float32) / 255.0
    else:
        pixels = image

    src_h, src_w = pixels.shape[:2]
    tile_size = 256
    step = tile_size - overlap

    # Pad image
    pad_bottom = (step - (src_h - overlap) % step) % step
    pad_right = (step - (src_w - overlap) % step) % step
    padded = np.pad(pixels, ((0, pad_bottom), (0, pad_right), (0, 0)), mode='reflect')

    full_h, full_w = padded.shape[:2]

    # Output arrays: weighted accumulator + weight sum for normalization.
    acc = np.zeros((full_h, full_w, 3), dtype=np.float32)
    acc_w = np.zeros((full_h, full_w, 1), dtype=np.float32)
    blend = create_blend_weights(tile_size, overlap)

    # Calculate positions
    rows = list(range(0, full_h - tile_size + 1, step))
    cols = list(range(0, full_w - tile_size + 1, step))
    total_tiles = len(rows) * len(cols)

    print(f"[E3Diff] Processing {total_tiles} tiles ({len(cols)}x{len(rows)}) at {src_w}x{src_h}...")

    done = 0
    for top in rows:
        for left in cols:
            # Extract tile
            patch = padded[top:top + tile_size, left:left + tile_size]
            patch_pil = Image.fromarray((patch * 255).astype(np.uint8))

            # Translate
            translated = translate_tile(model, patch_pil, device, seed=42)
            translated_np = np.array(translated).astype(np.float32) / 255.0

            # Blend
            acc[top:top + tile_size, left:left + tile_size] += translated_np * blend
            acc_w[top:top + tile_size, left:left + tile_size] += blend

            done += 1
            if done % 10 == 0 or done == total_tiles:
                print(f"[E3Diff] Tile {done}/{total_tiles}")

    # Normalize and crop the padding back off.
    merged = acc / (acc_w + 1e-8)
    return (merged[:src_h, :src_w] * 255).astype(np.uint8)
291
+
292
+
293
+ # Global model cache
294
+ _cached_model = None
295
+
296
+
297
+ def _translate_impl(file, overlap, enhance_output):
298
+ """Main translation function - runs on GPU."""
299
+ global _cached_model
300
 
301
  if file is None:
302
  return None, None, "Please upload a SAR image"
303
 
304
+ torch = get_torch()
305
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
306
+ print(f"[E3Diff] Using device: {device}")
307
+
308
+ # Load model (cached)
309
+ if _cached_model is None:
310
+ _cached_model = build_model(device)
311
 
312
+ model = _cached_model
313
 
314
+ # Load image
315
  filepath = file.name if hasattr(file, 'name') else file
316
+ print(f"[E3Diff] Loading: {filepath}")
317
  image = load_sar_image(filepath)
318
 
319
  w, h = image.size
320
+ print(f"[E3Diff] Input size: {w}x{h}")
321
 
322
  start = time.time()
323
+ result = process_image(image, model, device, overlap=int(overlap))
324
  elapsed = time.time() - start
325
 
326
  result_pil = Image.fromarray(result)
327
 
328
  if enhance_output:
329
+ result_pil = enhance_image(result_pil)
330
 
331
  tiff_path = tempfile.mktemp(suffix='.tiff')
332
  result_pil.save(tiff_path, format='TIFF', compression='lzw')
333
 
334
+ print(f"[E3Diff] Complete in {elapsed:.1f}s!")
335
 
336
  info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"
337
 
338
  return result_pil, tiff_path, info
339
 
340
 
341
# Expose the handler through the ZeroGPU decorator when the `spaces`
# runtime is available; otherwise use the plain implementation directly.
if GPU_AVAILABLE and spaces is not None:
    translate_sar = spaces.GPU(duration=300)(_translate_impl)
else:
    translate_sar = _translate_impl
348
+
349
 
350
+ print("[E3Diff] Building Gradio interface...")
351
 
352
  # Create Gradio interface
353
  with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
 
384
  **Note:** E3Diff is a one-step diffusion model. Multiple steps degrade quality.
385
  """)
386
 
387
+ print("[E3Diff] Launching app...")
388
 
389
  if __name__ == "__main__":
390
  demo.queue().launch(ssr_mode=False)