# SGAN / app.py
# ParamAhuja
# ui
# 6b320c6
import gradio as gr
from PIL import Image
import os
import torch
from models import get_available_models
from inference import run_inference
# Choose the compute device once at import time: CUDA when torch reports it
# as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Shared model registry; it is handed to run_inference() on every request.
# NOTE(review): presumably this loads checkpoints from the "models" directory
# onto `device` -- confirm against models.get_available_models.
models_dict = get_available_models(model_dir="models", device=device)
def save_image(img, prefix):
    """Write ``img`` to a PNG file and return the path of that file.

    Every call saves into its own freshly created temporary directory, so
    overlapping requests that use the same ``prefix`` cannot overwrite one
    another's output. The base file name is still ``{prefix}_output.png``.

    Args:
        img: Object exposing ``save(path)`` (for example a PIL image), or
            ``None``.
        prefix: Text placed in front of ``_output.png`` in the file name.

    Returns:
        Path of the written file, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # Function-scope import so the block carries its own dependency.
    import tempfile

    # A fixed name in the working directory would be shared by every request;
    # a unique directory keeps the same base name while avoiding collisions.
    out_dir = tempfile.mkdtemp(prefix="sr_output_")
    path = os.path.join(out_dir, f"{prefix}_output.png")
    img.save(path)
    return path
def process_image(img, use_x8):
    """Run super-resolution on ``img`` and assemble the nine UI outputs.

    Args:
        img: Input image exposing ``size`` and ``resize`` (the UI supplies a
            PIL image), or ``None`` when nothing was provided.
        use_x8: When truthy, the comparison baseline is enlarged 8x instead
            of 4x and ``x8_mode`` is forwarded to ``run_inference``.

    Returns:
        Nine values ordered as three slider pairs, three result images and
        three saved-file paths, each group in the order srcnn, satlas,
        esrgan. A model with no result contributes ``None`` in all three
        groups. With ``img`` set to ``None`` the result is a list of nine
        ``None`` values; otherwise it is a tuple.
    """
    if img is None:
        # Nothing to process: blank out all nine output components.
        return [None] * 9

    sr_results = run_inference(img, models_dict, x8_mode=use_x8, device=device)

    # Enlarge the input with bicubic resampling so the slider compares two
    # images of the intended output resolution.
    factor = 8 if use_x8 else 4
    width, height = img.size
    baseline = img.resize((width * factor, height * factor), Image.BICUBIC)

    slider_pairs, final_images, download_paths = [], [], []
    for model_key in ("srcnn", "satlas", "esrgan"):
        upscaled = sr_results[model_key]
        produced = upscaled is not None
        slider_pairs.append((baseline, upscaled) if produced else None)
        final_images.append(upscaled if produced else None)
        download_paths.append(save_image(upscaled, model_key) if produced else None)

    return (*slider_pairs, *final_images, *download_paths)
# Custom stylesheet passed to ``demo.launch(css=...)``. The ``.model-panel``
# class is attached to each per-model ``gr.Group`` through ``elem_classes``.
css = """
.model-panel { border: 1px solid #ccc; padding: 10px; border-radius: 8px; margin-bottom: 20px; }
"""
# Imported once, ahead of the layout, instead of on every pass of the
# per-model panel loop.
from gradio_imageslider import ImageSlider

# Top-level UI: input controls on the left, one result panel per model on the
# right, and a single click handler that fills all nine output components.
with gr.Blocks(title="Super Resolution Comparative Space") as demo:
    gr.Markdown("# High-Res Super Resolution with Tile-Based Inference")
    gr.Markdown("Compare SRCNN, SpectraGAN, and ESRGAN on a single image.")
    with gr.Row():
        # Input column: source image, x8 toggle and the trigger button.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            x8_checkbox = gr.Checkbox(
                label="Enable x8 Post-Resize",
                info="bicubic x2 pass on top of the x4 result",
            )
            submit_btn = gr.Button("Upscale", variant="primary")
        # Result column: components are stored per model key so the output
        # list below can be built from the same `names` sequence.
        with gr.Column(scale=3):
            # Model panels
            sliders = {}
            images = {}
            files = {}
            names = ["srcnn", "satlas", "esrgan"]
            disp_names = ["SRCNN x4", "SpectraGAN x4", "ESRGAN x4"]
            for name, disp_name in zip(names, disp_names):
                with gr.Group(elem_classes="model-panel"):
                    gr.Markdown(f"### {disp_name}")
                    sliders[name] = ImageSlider(
                        label=f"{disp_name} Comparison",
                        show_label=False,
                    )
                    images[name] = gr.Image(
                        type="pil",
                        label="Final Output",
                        interactive=False,
                    )
                    files[name] = gr.File(
                        label=f"Download {disp_name} PNG",
                        interactive=False,
                    )
    # Collect outputs in order: 3 sliders, then 3 images, then 3 files.
    # This must match the order of the values returned by process_image.
    outputs = (
        [sliders[key] for key in names]
        + [images[key] for key in names]
        + [files[key] for key in names]
    )
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, x8_checkbox],
        outputs=outputs,
    )
if __name__ == "__main__":
    # Start the server only when this file is executed as a script, not when
    # it is imported.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        ssr_mode=False,
        css=css,
    )