Dhenenjay commited on
Commit
298f49d
·
verified ·
1 Parent(s): 98d98f5

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +20 -2
app.py CHANGED
@@ -22,6 +22,14 @@ import time
22
  from functools import partial
23
  from huggingface_hub import hf_hub_download
24
 
 
 
 
 
 
 
 
 
25
  # ============================================================================
26
  # SoftPool Implementation (Pure PyTorch)
27
  # ============================================================================
@@ -160,6 +168,9 @@ class ResnetBlock(nn.Module):
160
  h = self.block1(x)
161
  h = self.noise_func(h, time_emb)
162
  h = self.block2(h)
 
 
 
163
  h = self.c_func(c) + h
164
  return h + self.res_conv(x)
165
 
@@ -674,8 +685,8 @@ def load_sar_image(filepath):
674
  return Image.open(filepath).convert('RGB')
675
 
676
 
677
- def translate_sar(file, num_steps, overlap, enhance_output):
678
- """Main translation function."""
679
  global processor
680
 
681
  if file is None:
@@ -711,6 +722,13 @@ def translate_sar(file, num_steps, overlap, enhance_output):
711
  return result_pil, tiff_path, info
712
 
713
 
 
 
 
 
 
 
 
714
  # Create interface
715
  with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
716
  gr.Markdown("""
 
22
  from functools import partial
23
  from huggingface_hub import hf_hub_download
24
 
25
+ # ZeroGPU support
26
+ try:
27
+ import spaces
28
+ GPU_AVAILABLE = True
29
+ except ImportError:
30
+ GPU_AVAILABLE = False
31
+ spaces = None
32
+
33
  # ============================================================================
34
  # SoftPool Implementation (Pure PyTorch)
35
  # ============================================================================
 
168
  h = self.block1(x)
169
  h = self.noise_func(h, time_emb)
170
  h = self.block2(h)
171
+ # Resize condition features to match spatial size
172
+ if c.shape[2:] != h.shape[2:]:
173
+ c = F.interpolate(c, size=h.shape[2:], mode='bilinear', align_corners=False)
174
  h = self.c_func(c) + h
175
  return h + self.res_conv(x)
176
 
 
685
  return Image.open(filepath).convert('RGB')
686
 
687
 
688
+ def _translate_sar_impl(file, num_steps, overlap, enhance_output):
689
+ """Main translation function implementation."""
690
  global processor
691
 
692
  if file is None:
 
722
  return result_pil, tiff_path, info
723
 
724
 
725
+ # Apply GPU decorator if available
726
+ if GPU_AVAILABLE and spaces is not None:
727
+ translate_sar = spaces.GPU(duration=300)(_translate_sar_impl)
728
+ else:
729
+ translate_sar = _translate_sar_impl
730
+
731
+
732
  # Create interface
733
  with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
734
  gr.Markdown("""