Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -581,16 +581,16 @@ def load_sar_image(filepath):
|
|
| 581 |
return Image.open(filepath).convert('RGB')
|
| 582 |
|
| 583 |
|
| 584 |
-
def translate_sar(image, num_steps, overlap, enhance
|
| 585 |
"""Main translation function."""
|
| 586 |
global model
|
| 587 |
|
| 588 |
if model is None:
|
| 589 |
-
|
| 590 |
model = E3DiffHighRes()
|
| 591 |
model.load_model()
|
| 592 |
|
| 593 |
-
|
| 594 |
|
| 595 |
# Handle file upload
|
| 596 |
if isinstance(image, str):
|
|
@@ -599,21 +599,17 @@ def translate_sar(image, num_steps, overlap, enhance, progress=gr.Progress()):
|
|
| 599 |
w, h = image.size
|
| 600 |
print(f"Input size: {w}x{h}")
|
| 601 |
|
| 602 |
-
# Progress callback
|
| 603 |
-
def update_progress(p):
|
| 604 |
-
progress(0.1 + 0.8 * p, desc=f"Translating... {int(p*100)}%")
|
| 605 |
-
|
| 606 |
# Translate
|
| 607 |
start = time.time()
|
| 608 |
result = model.translate_full_resolution(
|
| 609 |
image,
|
| 610 |
num_steps=num_steps,
|
| 611 |
overlap=overlap,
|
| 612 |
-
progress_callback=
|
| 613 |
)
|
| 614 |
elapsed = time.time() - start
|
| 615 |
|
| 616 |
-
|
| 617 |
|
| 618 |
# Convert to PIL
|
| 619 |
result_pil = Image.fromarray((result * 255).astype(np.uint8))
|
|
@@ -626,7 +622,7 @@ def translate_sar(image, num_steps, overlap, enhance, progress=gr.Progress()):
|
|
| 626 |
tiff_path = tempfile.mktemp(suffix='.tiff')
|
| 627 |
result_pil.save(tiff_path, format='TIFF', compression='lzw')
|
| 628 |
|
| 629 |
-
|
| 630 |
|
| 631 |
info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"
|
| 632 |
|
|
@@ -634,7 +630,7 @@ def translate_sar(image, num_steps, overlap, enhance, progress=gr.Progress()):
|
|
| 634 |
|
| 635 |
|
| 636 |
# Create Gradio interface
|
| 637 |
-
with gr.Blocks(title="E3Diff: SAR-to-Optical Translation"
|
| 638 |
gr.Markdown("""
|
| 639 |
# 🛰️ E3Diff: High-Resolution SAR-to-Optical Translation
|
| 640 |
|
|
|
|
| 581 |
return Image.open(filepath).convert('RGB')
|
| 582 |
|
| 583 |
|
| 584 |
+
def translate_sar(image, num_steps, overlap, enhance):
|
| 585 |
"""Main translation function."""
|
| 586 |
global model
|
| 587 |
|
| 588 |
if model is None:
|
| 589 |
+
print("Loading model...")
|
| 590 |
model = E3DiffHighRes()
|
| 591 |
model.load_model()
|
| 592 |
|
| 593 |
+
print("Processing image...")
|
| 594 |
|
| 595 |
# Handle file upload
|
| 596 |
if isinstance(image, str):
|
|
|
|
| 599 |
w, h = image.size
|
| 600 |
print(f"Input size: {w}x{h}")
|
| 601 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 602 |
# Translate
|
| 603 |
start = time.time()
|
| 604 |
result = model.translate_full_resolution(
|
| 605 |
image,
|
| 606 |
num_steps=num_steps,
|
| 607 |
overlap=overlap,
|
| 608 |
+
progress_callback=None
|
| 609 |
)
|
| 610 |
elapsed = time.time() - start
|
| 611 |
|
| 612 |
+
print("Post-processing...")
|
| 613 |
|
| 614 |
# Convert to PIL
|
| 615 |
result_pil = Image.fromarray((result * 255).astype(np.uint8))
|
|
|
|
| 622 |
tiff_path = tempfile.mktemp(suffix='.tiff')
|
| 623 |
result_pil.save(tiff_path, format='TIFF', compression='lzw')
|
| 624 |
|
| 625 |
+
print("Complete!")
|
| 626 |
|
| 627 |
info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"
|
| 628 |
|
|
|
|
| 630 |
|
| 631 |
|
| 632 |
# Create Gradio interface
|
| 633 |
+
with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
|
| 634 |
gr.Markdown("""
|
| 635 |
# 🛰️ E3Diff: High-Resolution SAR-to-Optical Translation
|
| 636 |
|