import os
import tempfile

import gradio as gr
import torch
from gradio_imageslider import ImageSlider
from PIL import Image

from inference import run_inference
from models import get_available_models
|
|
# Run the models on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Loaded once at import time; process_image reads this module-level mapping on
# every request. "models" is a relative path, so it is presumably resolved
# against the current working directory (depends on get_available_models).
models_dict = get_available_models(model_dir="models", device=device)
|
|
| def save_image(img, prefix): |
| if img is None: |
| return None |
| path = f"{prefix}_output.png" |
| img.save(path) |
| return path |
|
|
def process_image(img, use_x8):
    """Run every model on *img* and build the values for all nine UI outputs.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.
        use_x8: Checkbox state, forwarded to ``run_inference`` as ``x8_mode``.

    Returns:
        Nine values in the order the click handler's ``outputs`` list is
        wired:
        - three ``(bicubic_baseline, result)`` pairs for the comparison
          sliders;
        - the three result images;
        - the three saved PNG paths.
        Each group runs srcnn, satlas, esrgan in that order. A model with no
        result contributes ``None`` to all three groups. With no input image,
        a list of nine ``None`` values is returned.
    """
    if img is None:
        return [None] * 9

    results = run_inference(img, models_dict, x8_mode=use_x8, device=device)

    slider_values, image_values, file_values = [], [], []
    # One bicubic baseline per distinct output size, shared by models whose
    # outputs have the same dimensions.
    baselines = {}
    for name in ("srcnn", "satlas", "esrgan"):
        # results is indexed by model key (assumed to be a dict). Using .get()
        # means a model missing from it yields empty panels, not a KeyError.
        res_img = results.get(name)
        if res_img is None:
            slider_values.append(None)
            image_values.append(None)
            file_values.append(None)
            continue

        # Size the baseline from the actual output instead of a hard-coded
        # x4/x8 factor. The comparison slider shows both images over the same
        # area, so they must have identical dimensions even when a model's
        # output is not exactly scale * input.
        size = res_img.size
        if size not in baselines:
            baselines[size] = img.resize(size, Image.BICUBIC)
        slider_values.append((baselines[size], res_img))
        image_values.append(res_img)
        file_values.append(save_image(res_img, name))

    return (*slider_values, *image_values, *file_values)
|
|
# Custom stylesheet: draws a rounded, bordered, padded box around each
# per-model panel (the gr.Group components tagged "model-panel").
css = (
    "\n"
    ".model-panel { border: 1px solid #ccc; padding: 10px; "
    "border-radius: 8px; margin-bottom: 20px; }\n"
)
|
|
with gr.Blocks(title="Super Resolution Comparative Space") as demo:
    gr.Markdown("# High-Res Super Resolution with Tile-Based Inference")
    gr.Markdown("Compare SRCNN, SpectraGAN, and ESRGAN on a single image.")

    with gr.Row():
        # Left column: input image, x8 toggle and the run button.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            x8_checkbox = gr.Checkbox(label="Enable x8 Post-Resize", info="bicubic x2 pass on top of the x4 result")
            submit_btn = gr.Button("Upscale", variant="primary")

        # Right column: one panel per model with a comparison slider, the
        # final image, and a download link.
        with gr.Column(scale=3):
            sliders = {}
            images = {}
            files = {}
            # Internal model keys (the same keys process_image looks up in the
            # inference results) paired with their on-screen titles.
            names = ["srcnn", "satlas", "esrgan"]
            disp_names = ["SRCNN x4", "SpectraGAN x4", "ESRGAN x4"]
            for name, title in zip(names, disp_names):
                with gr.Group(elem_classes="model-panel"):
                    gr.Markdown(f"### {title}")
                    sliders[name] = ImageSlider(label=f"{title} Comparison", show_label=False)
                    images[name] = gr.Image(type="pil", label="Final Output", interactive=False)
                    files[name] = gr.File(label=f"Download {title} PNG", interactive=False)

    # Must mirror process_image's return order: all sliders, then all images,
    # then all files, each group in `names` order.
    outputs = (
        [sliders[n] for n in names]
        + [images[n] for n in names]
        + [files[n] for n in names]
    )

    submit_btn.click(
        fn=process_image,
        inputs=[input_image, x8_checkbox],
        outputs=outputs,
    )
|
|
if __name__ == "__main__":
    # Listen on every interface at port 7860, with server-side rendering and
    # public share links both turned off.
    # NOTE(review): css is passed to launch(), which matches the Gradio 6 API.
    # Older Gradio versions expect it on gr.Blocks(); confirm the pinned version.
    launch_options = {
        "css": css,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "ssr_mode": False,
        "share": False,
    }
    demo.launch(**launch_options)
|
|