Dhenenjay commited on
Commit
b4c9392
·
verified ·
1 Parent(s): 385037b

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +400 -400
app.py CHANGED
@@ -1,400 +1,400 @@
1
- """
2
- E3Diff: SAR-to-Optical Translation - HuggingFace Space
3
- Fixed for ZeroGPU with lazy loading
4
- """
5
-
6
- import os
7
- import numpy as np
8
- from PIL import Image, ImageEnhance
9
- import gradio as gr
10
- import tempfile
11
- import time
12
-
13
- print("[Axion] Starting app...")
14
-
15
- # ZeroGPU support
16
- try:
17
- import spaces
18
- GPU_AVAILABLE = True
19
- print("[Axion] ZeroGPU available")
20
- except ImportError:
21
- GPU_AVAILABLE = False
22
- spaces = None
23
- print("[Axion] Running without ZeroGPU")
24
-
25
-
26
- # Lazy imports for heavy modules
27
- _torch = None
28
- _model_modules = None
29
-
30
- def get_torch():
31
- global _torch
32
- if _torch is None:
33
- print("[Axion] Importing torch...")
34
- import torch
35
- _torch = torch
36
- print(f"[Axion] PyTorch {torch.__version__} loaded")
37
- return _torch
38
-
39
- def get_model_modules():
40
- global _model_modules
41
- if _model_modules is None:
42
- print("[Axion] Importing model modules...")
43
- from unet import UNet
44
- from diffusion import GaussianDiffusion
45
- _model_modules = (UNet, GaussianDiffusion)
46
- print("[Axion] Model modules loaded")
47
- return _model_modules
48
-
49
-
50
- def load_sar_image(filepath):
51
- """Load SAR image from various formats."""
52
- try:
53
- import rasterio
54
- with rasterio.open(filepath) as src:
55
- data = src.read(1)
56
- if data.dtype in [np.float32, np.float64]:
57
- valid = data[np.isfinite(data)]
58
- if len(valid) > 0:
59
- p2, p98 = np.percentile(valid, [2, 98])
60
- data = np.clip(data, p2, p98)
61
- data = ((data - p2) / (p98 - p2 + 1e-8) * 255).astype(np.uint8)
62
- elif data.dtype == np.uint16:
63
- p2, p98 = np.percentile(data, [2, 98])
64
- data = np.clip(data, p2, p98)
65
- data = ((data - p2) / (p98 - p2 + 1e-8) * 255).astype(np.uint8)
66
- return Image.fromarray(data).convert('RGB')
67
- except:
68
- pass
69
-
70
- return Image.open(filepath).convert('RGB')
71
-
72
-
73
- def create_blend_weights(tile_size, overlap):
74
- """Create smooth blending weights for seamless output."""
75
- ramp = np.linspace(0, 1, overlap)
76
- weight = np.ones((tile_size, tile_size))
77
- weight[:overlap, :] *= ramp[:, np.newaxis]
78
- weight[-overlap:, :] *= ramp[::-1, np.newaxis]
79
- weight[:, :overlap] *= ramp[np.newaxis, :]
80
- weight[:, -overlap:] *= ramp[np.newaxis, ::-1]
81
- return weight[:, :, np.newaxis]
82
-
83
-
84
- def build_model(device):
85
- """Build and load the E3Diff model."""
86
- torch = get_torch()
87
- UNet, GaussianDiffusion = get_model_modules()
88
- from huggingface_hub import hf_hub_download
89
-
90
- print("[Axion] Building model architecture...")
91
-
92
- image_size = 256
93
- num_inference_steps = 1
94
-
95
- # UNet configuration
96
- unet = UNet(
97
- in_channel=3,
98
- out_channel=3,
99
- norm_groups=16,
100
- inner_channel=64,
101
- channel_mults=[1, 2, 4, 8, 16],
102
- attn_res=[],
103
- res_blocks=1,
104
- dropout=0,
105
- image_size=image_size,
106
- condition_ch=3
107
- )
108
-
109
- # Diffusion wrapper
110
- schedule_opt = {
111
- 'schedule': 'linear',
112
- 'n_timestep': num_inference_steps,
113
- 'linear_start': 1e-6,
114
- 'linear_end': 1e-2,
115
- 'ddim': 1,
116
- 'lq_noiselevel': 0
117
- }
118
-
119
- opt = {
120
- 'stage': 2,
121
- 'ddim_steps': num_inference_steps,
122
- 'model': {
123
- 'beta_schedule': {
124
- 'train': {'n_timestep': 1000},
125
- 'val': schedule_opt
126
- }
127
- }
128
- }
129
-
130
- model = GaussianDiffusion(
131
- denoise_fn=unet,
132
- image_size=image_size,
133
- channels=3,
134
- loss_type='l1',
135
- conditional=True,
136
- schedule_opt=schedule_opt,
137
- xT_noise_r=0,
138
- seed=1,
139
- opt=opt
140
- )
141
-
142
- model = model.to(device)
143
-
144
- # Load weights
145
- print("[Axion] Downloading weights...")
146
- weights_path = hf_hub_download(
147
- repo_id="Dhenenjay/E3Diff-SAR2Optical",
148
- filename="I700000_E719_gen.pth"
149
- )
150
-
151
- print(f"[Axion] Loading weights from: {weights_path}")
152
- state_dict = torch.load(weights_path, map_location=device, weights_only=False)
153
- model.load_state_dict(state_dict, strict=False)
154
- model.eval()
155
-
156
- print("[Axion] Model ready!")
157
- return model
158
-
159
-
160
- def preprocess(image, device, image_size=256):
161
- """Preprocess input SAR image."""
162
- torch = get_torch()
163
-
164
- if image.mode != 'RGB':
165
- image = image.convert('RGB')
166
-
167
- if image.size != (image_size, image_size):
168
- image = image.resize((image_size, image_size), Image.LANCZOS)
169
-
170
- img_np = np.array(image).astype(np.float32) / 255.0
171
- img_tensor = torch.from_numpy(img_np).permute(2, 0, 1)
172
- img_tensor = img_tensor * 2.0 - 1.0
173
-
174
- return img_tensor.unsqueeze(0).to(device)
175
-
176
-
177
- def postprocess(tensor):
178
- """Postprocess output tensor to PIL Image."""
179
- torch = get_torch()
180
-
181
- tensor = tensor.squeeze(0).cpu()
182
- tensor = torch.clamp(tensor, -1, 1)
183
- tensor = (tensor + 1.0) / 2.0
184
-
185
- img_np = (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
186
- return Image.fromarray(img_np)
187
-
188
-
189
- def translate_tile(model, sar_pil, device, seed=42):
190
- """Translate a single tile."""
191
- torch = get_torch()
192
-
193
- if seed is not None:
194
- torch.manual_seed(seed)
195
- np.random.seed(seed)
196
-
197
- sar_tensor = preprocess(sar_pil, device)
198
-
199
- model.set_new_noise_schedule(
200
- {
201
- 'schedule': 'linear',
202
- 'n_timestep': 1,
203
- 'linear_start': 1e-6,
204
- 'linear_end': 1e-2,
205
- 'ddim': 1,
206
- 'lq_noiselevel': 0
207
- },
208
- device,
209
- num_train_timesteps=1000
210
- )
211
-
212
- with torch.no_grad():
213
- output, _ = model.super_resolution(
214
- sar_tensor,
215
- continous=False,
216
- seed=seed if seed is not None else 1,
217
- img_s1=sar_tensor
218
- )
219
-
220
- return postprocess(output)
221
-
222
-
223
- def enhance_image(image, contrast=1.1, sharpness=1.2, color=1.1):
224
- """Professional post-processing."""
225
- if isinstance(image, np.ndarray):
226
- image = Image.fromarray(image)
227
-
228
- image = ImageEnhance.Contrast(image).enhance(contrast)
229
- image = ImageEnhance.Sharpness(image).enhance(sharpness)
230
- image = ImageEnhance.Color(image).enhance(color)
231
-
232
- return image
233
-
234
-
235
- def process_image(image, model, device, overlap=64):
236
- """Process image at full resolution with seamless tiling."""
237
- if isinstance(image, Image.Image):
238
- if image.mode != 'RGB':
239
- image = image.convert('RGB')
240
- img_np = np.array(image).astype(np.float32) / 255.0
241
- else:
242
- img_np = image
243
-
244
- h, w = img_np.shape[:2]
245
- tile_size = 256
246
- step = tile_size - overlap
247
-
248
- # Pad image
249
- pad_h = (step - (h - overlap) % step) % step
250
- pad_w = (step - (w - overlap) % step) % step
251
- img_padded = np.pad(img_np, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
252
-
253
- h_pad, w_pad = img_padded.shape[:2]
254
-
255
- # Output arrays
256
- output = np.zeros((h_pad, w_pad, 3), dtype=np.float32)
257
- weights = np.zeros((h_pad, w_pad, 1), dtype=np.float32)
258
- blend_weight = create_blend_weights(tile_size, overlap)
259
-
260
- # Calculate positions
261
- y_positions = list(range(0, h_pad - tile_size + 1, step))
262
- x_positions = list(range(0, w_pad - tile_size + 1, step))
263
- total_tiles = len(y_positions) * len(x_positions)
264
-
265
- print(f"[Axion] Processing {total_tiles} tiles ({len(x_positions)}x{len(y_positions)}) at {w}x{h}...")
266
-
267
- tile_idx = 0
268
- for y in y_positions:
269
- for x in x_positions:
270
- # Extract tile
271
- tile = img_padded[y:y+tile_size, x:x+tile_size]
272
- tile_pil = Image.fromarray((tile * 255).astype(np.uint8))
273
-
274
- # Translate
275
- result_pil = translate_tile(model, tile_pil, device, seed=42)
276
- result = np.array(result_pil).astype(np.float32) / 255.0
277
-
278
- # Blend
279
- output[y:y+tile_size, x:x+tile_size] += result * blend_weight
280
- weights[y:y+tile_size, x:x+tile_size] += blend_weight
281
-
282
- tile_idx += 1
283
- if tile_idx % 10 == 0 or tile_idx == total_tiles:
284
- print(f"[Axion] Tile {tile_idx}/{total_tiles}")
285
-
286
- # Normalize
287
- output = output / (weights + 1e-8)
288
- output = output[:h, :w]
289
-
290
- return (output * 255).astype(np.uint8)
291
-
292
-
293
- # Global model cache
294
- _cached_model = None
295
-
296
-
297
- def _translate_impl(file, overlap, enhance_output):
298
- """Main translation function - runs on GPU."""
299
- global _cached_model
300
-
301
- if file is None:
302
- return None, None, "Please upload a SAR image"
303
-
304
- torch = get_torch()
305
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
306
- print(f"[Axion] Using device: {device}")
307
-
308
- # Load model (cached)
309
- if _cached_model is None:
310
- _cached_model = build_model(device)
311
-
312
- model = _cached_model
313
-
314
- # Load image
315
- filepath = file.name if hasattr(file, 'name') else file
316
- print(f"[Axion] Loading: {filepath}")
317
- image = load_sar_image(filepath)
318
-
319
- w, h = image.size
320
- print(f"[Axion] Input size: {w}x{h}")
321
-
322
- start = time.time()
323
- result = process_image(image, model, device, overlap=int(overlap))
324
- elapsed = time.time() - start
325
-
326
- result_pil = Image.fromarray(result)
327
-
328
- if enhance_output:
329
- result_pil = enhance_image(result_pil)
330
-
331
- tiff_path = tempfile.mktemp(suffix='.tiff')
332
- result_pil.save(tiff_path, format='TIFF', compression='lzw')
333
-
334
- print(f"[Axion] Complete in {elapsed:.1f}s!")
335
-
336
- info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"
337
-
338
- return result_pil, tiff_path, info
339
-
340
-
341
- # Apply GPU decorator
342
- if GPU_AVAILABLE and spaces is not None:
343
- @spaces.GPU(duration=300)
344
- def translate_sar(file, overlap, enhance_output):
345
- return _translate_impl(file, overlap, enhance_output)
346
- else:
347
- translate_sar = _translate_impl
348
-
349
-
350
- print("[Axion] Building Gradio interface...")
351
-
352
- # Create Gradio interface
353
- with gr.Blocks(title="Axion - SAR to Optical") as demo:
354
- gr.HTML("""
355
- <style>
356
- .gradio-container { background: linear-gradient(180deg, #0a0a0a 0%, #1a1a1a 100%) !important; }
357
- </style>
358
- <div style="text-align: center; padding: 40px 20px 20px 20px;">
359
- <h1 style="font-family: 'Helvetica Neue', Arial, sans-serif; font-size: 3.2rem; font-weight: 200; color: #ffffff; margin-bottom: 0.5rem; letter-spacing: -0.02em;">SAR to Optical Image Translation</h1>
360
- <p style="font-family: 'Helvetica Neue', Arial, sans-serif; font-size: 1.1rem; font-weight: 300; color: #888888;">Transform radar imagery into crystal-clear optical views using our foundation model</p>
361
- </div>
362
- """)
363
-
364
- with gr.Row():
365
- with gr.Column():
366
- input_file = gr.File(label="Upload SAR Image", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
367
- gr.HTML("""
368
- <div style="font-size: 0.8rem; color: #666; padding: 8px 12px; background: rgba(255,255,255,0.03); border-radius: 6px; margin: 8px 0;">
369
- <strong style="color: #888;">Input Guidelines:</strong><br>
370
- Use raw SAR imagery (single-band grayscale)<br>
371
- VV polarization preferred, VH also supported<br>
372
- Any resolution supported (processed in 256×256 tiles)
373
- </div>
374
- """)
375
- with gr.Row():
376
- overlap = gr.Slider(16, 128, value=64, step=16, label="Tile Overlap")
377
- enhance = gr.Checkbox(value=True, label="Enhance Output")
378
- submit_btn = gr.Button("Translate", variant="primary")
379
-
380
- with gr.Column():
381
- output_image = gr.Image(label="Optical Output")
382
- output_file = gr.File(label="Download")
383
- info_text = gr.Textbox(label="Info", show_label=False)
384
-
385
- submit_btn.click(
386
- fn=translate_sar,
387
- inputs=[input_file, overlap, enhance],
388
- outputs=[output_image, output_file, info_text]
389
- )
390
-
391
- gr.HTML("""
392
- <div style="text-align: center; padding: 20px; color: #555; font-size: 0.85rem;">
393
- Powered by <strong style="color: #888;">Axion</strong>
394
- </div>
395
- """)
396
-
397
- print("[Axion] Launching app...")
398
-
399
- if __name__ == "__main__":
400
- demo.queue().launch(ssr_mode=False)
 
1
+ """
2
+ E3Diff: SAR-to-Optical Translation - HuggingFace Space
3
+ Fixed for ZeroGPU with lazy loading
4
+ """
5
+
6
+ import os
7
+ import numpy as np
8
+ from PIL import Image, ImageEnhance
9
+ import gradio as gr
10
+ import tempfile
11
+ import time
12
+
13
+ print("[Axion] Starting app...")
14
+
15
# ZeroGPU support
# Detect the HuggingFace `spaces` package (present on ZeroGPU Spaces).
# GPU_AVAILABLE gates the @spaces.GPU decorator applied further below.
try:
    import spaces
    GPU_AVAILABLE = True
    print("[Axion] ZeroGPU available")
except ImportError:
    GPU_AVAILABLE = False
    spaces = None
    print("[Axion] Running without ZeroGPU")


# Lazy imports for heavy modules
# torch and the model code are imported on first use (see get_torch /
# get_model_modules) so the Space UI starts quickly.
_torch = None
_model_modules = None
29
+
30
def get_torch():
    """Import torch on first use and return the cached module thereafter."""
    global _torch
    if _torch is not None:
        return _torch
    print("[Axion] Importing torch...")
    import torch
    _torch = torch
    print(f"[Axion] PyTorch {torch.__version__} loaded")
    return _torch
38
+
39
def get_model_modules():
    """Lazily import the project model classes, caching (UNet, GaussianDiffusion)."""
    global _model_modules
    if _model_modules is not None:
        return _model_modules
    print("[Axion] Importing model modules...")
    from unet import UNet
    from diffusion import GaussianDiffusion
    _model_modules = (UNet, GaussianDiffusion)
    print("[Axion] Model modules loaded")
    return _model_modules
48
+
49
+
50
def load_sar_image(filepath):
    """Load a SAR image from *filepath* as an RGB PIL image.

    First tries rasterio so single-band GeoTIFFs are handled: float and
    uint16 rasters are percentile-stretched (2nd..98th) to uint8 before
    conversion. If rasterio is unavailable or cannot read the file, falls
    back to plain PIL loading.
    """
    try:
        import rasterio
        with rasterio.open(filepath) as src:
            data = src.read(1)
            if data.dtype in [np.float32, np.float64]:
                # Ignore NaN/Inf pixels when computing the stretch range.
                valid = data[np.isfinite(data)]
                if len(valid) > 0:
                    p2, p98 = np.percentile(valid, [2, 98])
                    data = np.clip(data, p2, p98)
                    data = ((data - p2) / (p98 - p2 + 1e-8) * 255).astype(np.uint8)
            elif data.dtype == np.uint16:
                p2, p98 = np.percentile(data, [2, 98])
                data = np.clip(data, p2, p98)
                data = ((data - p2) / (p98 - p2 + 1e-8) * 255).astype(np.uint8)
            return Image.fromarray(data).convert('RGB')
    except Exception as exc:
        # Was a bare `except:` that silently swallowed everything (including
        # KeyboardInterrupt/SystemExit). Narrowed to Exception and logged so
        # the fallback reason is visible in the Space logs.
        print(f"[Axion] rasterio load failed ({exc!r}); falling back to PIL")

    return Image.open(filepath).convert('RGB')
71
+
72
+
73
def create_blend_weights(tile_size, overlap):
    """Create per-pixel blend weights for feathering overlapping tiles.

    Returns a (tile_size, tile_size, 1) float array that is 1.0 in the tile
    interior and ramps linearly down inside each *overlap*-pixel border.
    Opposing ramps of two adjacent tiles sum to exactly 1.0, so seam pixels
    accumulate unit weight.

    The ramp deliberately excludes 0: with the previous `linspace(0, 1, overlap)`
    the outermost pixel rows had weight exactly 0, so image-border pixels
    (covered by a single tile) accumulated zero weight and came out black
    after the caller's `output / (weights + 1e-8)` normalization.
    """
    # overlap+2 samples over [0, 1] with both endpoints dropped gives values
    # (1..overlap)/(overlap+1): strictly positive, and
    # ramp[i] + ramp[overlap-1-i] == 1 (partition of unity at seams).
    ramp = np.linspace(0.0, 1.0, overlap + 2)[1:-1]
    weight = np.ones((tile_size, tile_size))
    weight[:overlap, :] *= ramp[:, np.newaxis]
    weight[-overlap:, :] *= ramp[::-1, np.newaxis]
    weight[:, :overlap] *= ramp[np.newaxis, :]
    weight[:, -overlap:] *= ramp[np.newaxis, ::-1]
    return weight[:, :, np.newaxis]
82
+
83
+
84
def build_model(device):
    """Build the E3Diff generator and load its pretrained weights from the Hub.

    Constructs a conditional UNet wrapped in a GaussianDiffusion sampler
    configured for single-step DDIM inference, moves it to *device*, then
    downloads and loads the checkpoint. Returns the model in eval mode.
    """
    torch = get_torch()
    UNet, GaussianDiffusion = get_model_modules()
    from huggingface_hub import hf_hub_download

    print("[Axion] Building model architecture...")

    image_size = 256          # tile resolution used throughout this app
    num_inference_steps = 1   # single-step sampling configuration

    # UNet configuration
    unet = UNet(
        in_channel=3,
        out_channel=3,
        norm_groups=16,
        inner_channel=64,
        channel_mults=[1, 2, 4, 8, 16],
        attn_res=[],
        res_blocks=1,
        dropout=0,
        image_size=image_size,
        condition_ch=3
    )

    # Diffusion wrapper
    # Inference-time ("val") noise schedule; training used 1000 timesteps.
    schedule_opt = {
        'schedule': 'linear',
        'n_timestep': num_inference_steps,
        'linear_start': 1e-6,
        'linear_end': 1e-2,
        'ddim': 1,
        'lq_noiselevel': 0
    }

    opt = {
        'stage': 2,
        'ddim_steps': num_inference_steps,
        'model': {
            'beta_schedule': {
                'train': {'n_timestep': 1000},
                'val': schedule_opt
            }
        }
    }

    model = GaussianDiffusion(
        denoise_fn=unet,
        image_size=image_size,
        channels=3,
        loss_type='l1',
        conditional=True,
        schedule_opt=schedule_opt,
        xT_noise_r=0,
        seed=1,
        opt=opt
    )

    model = model.to(device)

    # Load weights
    print("[Axion] Downloading weights...")
    weights_path = hf_hub_download(
        repo_id="Dhenenjay/Axion-S2O",
        filename="I700000_E719_gen.pth"
    )

    print(f"[Axion] Loading weights from: {weights_path}")
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only
    # acceptable because the checkpoint comes from our own trusted repo.
    state_dict = torch.load(weights_path, map_location=device, weights_only=False)
    # NOTE(review): strict=False silently ignores missing/unexpected keys --
    # confirm the checkpoint actually matches this architecture.
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    print("[Axion] Model ready!")
    return model
158
+
159
+
160
def preprocess(image, device, image_size=256):
    """Convert a SAR PIL image into a [-1, 1] CHW tensor batch on *device*."""
    torch = get_torch()

    rgb = image if image.mode == 'RGB' else image.convert('RGB')
    if rgb.size != (image_size, image_size):
        rgb = rgb.resize((image_size, image_size), Image.LANCZOS)

    # HWC uint8 -> CHW float in [0, 1], then rescale to [-1, 1].
    arr = np.array(rgb).astype(np.float32) / 255.0
    tensor = torch.from_numpy(arr).permute(2, 0, 1) * 2.0 - 1.0

    return tensor.unsqueeze(0).to(device)
175
+
176
+
177
def postprocess(tensor):
    """Map a [-1, 1] model output tensor back to an RGB PIL image."""
    torch = get_torch()

    # Drop the batch dim, move to CPU, and rescale [-1, 1] -> [0, 1].
    out = torch.clamp(tensor.squeeze(0).cpu(), -1, 1)
    out = (out + 1.0) / 2.0

    # CHW float -> HWC uint8.
    arr = (out.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return Image.fromarray(arr)
187
+
188
+
189
def translate_tile(model, sar_pil, device, seed=42):
    """Translate a single SAR tile to an optical RGB PIL image.

    Seeds torch/numpy for reproducible sampling, rebuilds the one-step DDIM
    noise schedule on the target device, then runs the diffusion sampler
    conditioned on the SAR tile.
    """
    torch = get_torch()

    if seed is not None:
        # A fixed seed keeps tile outputs deterministic so overlapping
        # regions of adjacent tiles blend without flicker.
        torch.manual_seed(seed)
        np.random.seed(seed)

    sar_tensor = preprocess(sar_pil, device)

    # Single-step inference schedule (matches schedule_opt in build_model).
    model.set_new_noise_schedule(
        {
            'schedule': 'linear',
            'n_timestep': 1,
            'linear_start': 1e-6,
            'linear_end': 1e-2,
            'ddim': 1,
            'lq_noiselevel': 0
        },
        device,
        num_train_timesteps=1000
    )

    with torch.no_grad():
        # NOTE(review): the SAR tensor is passed both as the main input and
        # as img_s1 conditioning -- presumably what this E3Diff API expects;
        # verify against the GaussianDiffusion.super_resolution signature.
        output, _ = model.super_resolution(
            sar_tensor,
            continous=False,
            seed=seed if seed is not None else 1,
            img_s1=sar_tensor
        )

    return postprocess(output)
221
+
222
+
223
+ def enhance_image(image, contrast=1.1, sharpness=1.2, color=1.1):
224
+ """Professional post-processing."""
225
+ if isinstance(image, np.ndarray):
226
+ image = Image.fromarray(image)
227
+
228
+ image = ImageEnhance.Contrast(image).enhance(contrast)
229
+ image = ImageEnhance.Sharpness(image).enhance(sharpness)
230
+ image = ImageEnhance.Color(image).enhance(color)
231
+
232
+ return image
233
+
234
+
235
def process_image(image, model, device, overlap=64):
    """Translate a full-resolution image by tiling into overlapped 256px tiles.

    The input is reflect-padded so an integer grid of tiles covers it, each
    tile is translated independently, the results are feathered together
    with linear blend weights, and the pad margin is cropped away.
    Returns an HxWx3 uint8 array.
    """
    if isinstance(image, Image.Image):
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_np = np.array(image).astype(np.float32) / 255.0
    else:
        img_np = image

    h, w = img_np.shape[:2]
    tile_size = 256
    step = tile_size - overlap

    # Reflect-pad so the overlapping tile grid covers the frame exactly.
    pad_h = (step - (h - overlap) % step) % step
    pad_w = (step - (w - overlap) % step) % step
    img_padded = np.pad(img_np, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
    h_pad, w_pad = img_padded.shape[:2]

    # Accumulators for the weighted tile mosaic.
    output = np.zeros((h_pad, w_pad, 3), dtype=np.float32)
    weights = np.zeros((h_pad, w_pad, 1), dtype=np.float32)
    blend_weight = create_blend_weights(tile_size, overlap)

    y_positions = list(range(0, h_pad - tile_size + 1, step))
    x_positions = list(range(0, w_pad - tile_size + 1, step))
    positions = [(y, x) for y in y_positions for x in x_positions]
    total_tiles = len(positions)

    print(f"[Axion] Processing {total_tiles} tiles ({len(x_positions)}x{len(y_positions)}) at {w}x{h}...")

    for tile_idx, (y, x) in enumerate(positions, start=1):
        # Extract and translate one tile.
        patch = img_padded[y:y+tile_size, x:x+tile_size]
        patch_pil = Image.fromarray((patch * 255).astype(np.uint8))
        translated_pil = translate_tile(model, patch_pil, device, seed=42)
        translated = np.array(translated_pil).astype(np.float32) / 255.0

        # Feather into the mosaic.
        output[y:y+tile_size, x:x+tile_size] += translated * blend_weight
        weights[y:y+tile_size, x:x+tile_size] += blend_weight

        if tile_idx % 10 == 0 or tile_idx == total_tiles:
            print(f"[Axion] Tile {tile_idx}/{total_tiles}")

    # Normalize by accumulated weight and drop the padding.
    output = output / (weights + 1e-8)
    output = output[:h, :w]

    return (output * 255).astype(np.uint8)
291
+
292
+
293
# Global model cache
# Module-level so repeated requests reuse the already-loaded weights.
_cached_model = None


def _translate_impl(file, overlap, enhance_output):
    """Translate an uploaded SAR image to optical imagery (GPU entry point).

    Args:
        file: Gradio file object (or plain path string) for the SAR input.
        overlap: tile overlap in pixels for seamless blending.
        enhance_output: whether to apply contrast/sharpness/color post-processing.

    Returns:
        (preview PIL image, path to a downloadable TIFF, status string), or
        (None, None, message) when no file was provided.
    """
    global _cached_model

    if file is None:
        return None, None, "Please upload a SAR image"

    torch = get_torch()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Axion] Using device: {device}")

    # Load model (cached across calls)
    if _cached_model is None:
        _cached_model = build_model(device)

    model = _cached_model

    # Load image
    filepath = file.name if hasattr(file, 'name') else file
    print(f"[Axion] Loading: {filepath}")
    image = load_sar_image(filepath)

    w, h = image.size
    print(f"[Axion] Input size: {w}x{h}")

    start = time.time()
    result = process_image(image, model, device, overlap=int(overlap))
    elapsed = time.time() - start

    result_pil = Image.fromarray(result)

    if enhance_output:
        result_pil = enhance_image(result_pil)

    # tempfile.mktemp() is deprecated and race-prone (the returned name can be
    # claimed between the call and the save). Create the file atomically and
    # keep it on disk for Gradio to serve.
    with tempfile.NamedTemporaryFile(suffix='.tiff', delete=False) as tmp:
        tiff_path = tmp.name
    result_pil.save(tiff_path, format='TIFF', compression='lzw')

    print(f"[Axion] Complete in {elapsed:.1f}s!")

    info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"

    return result_pil, tiff_path, info
339
+
340
+
341
# Apply GPU decorator
# On ZeroGPU Spaces the entry point must be wrapped with @spaces.GPU so a GPU
# is attached for the duration of the call (up to 300 s here); locally we
# expose the implementation directly.
if GPU_AVAILABLE and spaces is not None:
    @spaces.GPU(duration=300)
    def translate_sar(file, overlap, enhance_output):
        return _translate_impl(file, overlap, enhance_output)
else:
    translate_sar = _translate_impl
348
+
349
+
350
print("[Axion] Building Gradio interface...")

# Create Gradio interface
# Layout: left column = upload + guidelines + controls; right column = results.
with gr.Blocks(title="Axion - SAR to Optical") as demo:
    # Header banner: dark-theme override plus title/subtitle.
    gr.HTML("""
    <style>
    .gradio-container { background: linear-gradient(180deg, #0a0a0a 0%, #1a1a1a 100%) !important; }
    </style>
    <div style="text-align: center; padding: 40px 20px 20px 20px;">
    <h1 style="font-family: 'Helvetica Neue', Arial, sans-serif; font-size: 3.2rem; font-weight: 200; color: #ffffff; margin-bottom: 0.5rem; letter-spacing: -0.02em;">SAR to Optical Image Translation</h1>
    <p style="font-family: 'Helvetica Neue', Arial, sans-serif; font-size: 1.1rem; font-weight: 300; color: #888888;">Transform radar imagery into crystal-clear optical views using our foundation model</p>
    </div>
    """)

    with gr.Row():
        with gr.Column():
            input_file = gr.File(label="Upload SAR Image", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
            gr.HTML("""
            <div style="font-size: 0.8rem; color: #666; padding: 8px 12px; background: rgba(255,255,255,0.03); border-radius: 6px; margin: 8px 0;">
            <strong style="color: #888;">Input Guidelines:</strong><br>
            • Use raw SAR imagery (single-band grayscale)<br>
            • VV polarization preferred, VH also supported<br>
            • Any resolution supported (processed in 256×256 tiles)
            </div>
            """)
            with gr.Row():
                overlap = gr.Slider(16, 128, value=64, step=16, label="Tile Overlap")
                enhance = gr.Checkbox(value=True, label="Enhance Output")
            submit_btn = gr.Button("Translate", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Optical Output")
            output_file = gr.File(label="Download")
            info_text = gr.Textbox(label="Info", show_label=False)

    # Wire the button to the (possibly GPU-decorated) translation entry point.
    submit_btn.click(
        fn=translate_sar,
        inputs=[input_file, overlap, enhance],
        outputs=[output_image, output_file, info_text]
    )

    gr.HTML("""
    <div style="text-align: center; padding: 20px; color: #555; font-size: 0.85rem;">
    Powered by <strong style="color: #888;">Axion</strong>
    </div>
    """)

print("[Axion] Launching app...")

if __name__ == "__main__":
    # queue() serializes requests (required for ZeroGPU); ssr_mode=False avoids
    # Gradio SSR issues on Spaces.
    demo.queue().launch(ssr_mode=False)