"""Gradio app comparing SRCNN, SpectraGAN (``satlas``) and ESRGAN super-resolution.

Each model's output is shown three ways: an image slider against a bicubic
upscale of the input, the standalone result, and a downloadable PNG.
"""

import os
import tempfile

import gradio as gr
import torch
from gradio_imageslider import ImageSlider
from PIL import Image

from inference import run_inference
from models import get_available_models

device = "cuda" if torch.cuda.is_available() else "cpu"
models_dict = get_available_models(model_dir="models", device=device)

# Keys looked up in the dict returned by run_inference, in display order.
# The UI panels and the handler's return order are both derived from this list.
MODEL_NAMES = ["srcnn", "satlas", "esrgan"]
DISPLAY_NAMES = ["SRCNN x4", "SpectraGAN x4", "ESRGAN x4"]


def save_image(img, prefix):
    """Save *img* as a PNG and return the file path (None if *img* is None).

    Every call writes into its own fresh temporary directory. Concurrent
    requests therefore cannot overwrite each other's downloads, which a fixed
    ``{prefix}_output.png`` in the working directory allowed. The file name
    is kept as ``{prefix}_output.png`` so the download name is unchanged.
    """
    if img is None:
        return None
    # mkdtemp lives under tempfile.gettempdir(), which Gradio serves by default.
    out_dir = tempfile.mkdtemp(prefix="sr_")
    path = os.path.join(out_dir, f"{prefix}_output.png")
    img.save(path)
    return path


def process_image(img, use_x8):
    """Run every model on *img* and build the values for the UI outputs.

    Args:
        img: The uploaded PIL image, or None when nothing was uploaded.
        use_x8: Forwarded to ``run_inference`` as ``x8_mode``.

    Returns:
        A flat sequence of ``3 * len(MODEL_NAMES)`` values, in the order of the
        click handler's ``outputs`` list:

        1. one ``(bicubic_input, result)`` pair per slider;
        2. the result images;
        3. the saved PNG paths.

        Every entry belonging to a model with no result is None.
    """
    n_models = len(MODEL_NAMES)
    if img is None:
        # 3 sliders, 3 standalone images, 3 file paths (for the default models)
        return [None] * (3 * n_models)

    # NOTE(review): assumes run_inference returns a dict keyed by MODEL_NAMES
    # whose values are PIL images or None — confirm against inference.py.
    results = run_inference(img, models_dict, x8_mode=use_x8, device=device)

    slider_values = []
    image_values = []
    file_values = []
    # Bicubic upscales of the input, keyed by target size, so models that
    # return the same resolution share one resize.
    baselines = {}
    for name in MODEL_NAMES:
        res_img = results.get(name)
        if res_img is None:
            slider_values.append(None)
            image_values.append(None)
            file_values.append(None)
            continue
        # Resize the input to the result's actual size instead of assuming an
        # exact x4/x8 factor, so the two slider halves always line up.
        size = res_img.size
        if size not in baselines:
            baselines[size] = img.resize(size, Image.BICUBIC)
        slider_values.append((baselines[size], res_img))
        image_values.append(res_img)
        file_values.append(save_image(res_img, name))

    return (*slider_values, *image_values, *file_values)


css = """
.model-panel {
    border: 1px solid #ccc;
    padding: 10px;
    border-radius: 8px;
    margin-bottom: 20px;
}
"""

with gr.Blocks(title="Super Resolution Comparative Space") as demo:
    gr.Markdown("# High-Res Super Resolution with Tile-Based Inference")
    gr.Markdown("Compare SRCNN, SpectraGAN, and ESRGAN on a single image.")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            x8_checkbox = gr.Checkbox(
                label="Enable x8 Post-Resize",
                info="bicubic x2 pass on top of the x4 result",
            )
            submit_btn = gr.Button("Upscale", variant="primary")

        with gr.Column(scale=3):
            # One panel per model: comparison slider, final image, download.
            sliders = {}
            images = {}
            files = {}
            for name, disp_name in zip(MODEL_NAMES, DISPLAY_NAMES):
                with gr.Group(elem_classes="model-panel"):
                    gr.Markdown(f"### {disp_name}")
                    sliders[name] = ImageSlider(
                        label=f"{disp_name} Comparison", show_label=False
                    )
                    images[name] = gr.Image(
                        type="pil", label="Final Output", interactive=False
                    )
                    files[name] = gr.File(
                        label=f"Download {disp_name} PNG", interactive=False
                    )

    # Order must match process_image's return: all sliders, all images, all files.
    outputs = (
        [sliders[name] for name in MODEL_NAMES]
        + [images[name] for name in MODEL_NAMES]
        + [files[name] for name in MODEL_NAMES]
    )

    submit_btn.click(
        fn=process_image,
        inputs=[input_image, x8_checkbox],
        outputs=outputs,
    )

if __name__ == "__main__":
    # NOTE(review): launch(css=...) matches Gradio 6, where css moved to launch();
    # on Gradio 4/5 it belongs on gr.Blocks(css=...) — confirm the pinned version.
    demo.launch(
        css=css,
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
        share=False,
    )