""" Gradio app for Image Denoising + 2x Super-Resolution. Entry point for Hugging Face Spaces. """ import os import sys import torch import gradio as gr from PIL import Image from models.autoencoder import SuperResAutoencoder from training.train import load_checkpoint # --------------------------------------------------------------------------- # Config (inline — no dependency on config.py paths) # --------------------------------------------------------------------------- NOISE_TYPES = ["gaussian", "salt_pepper", "speckle"] SR_PATCH_THRESHOLD = 200 SR_PATCH_SIZE = 48 SR_OUTPUT_SIZE = 96 SR_PATCH_OVERLAP = 8 CHECKPOINT_DIR = "checkpoints" NOISE_LABELS = { "gaussian": "Gaussian (film grain / camera noise)", "salt_pepper": "Salt & Pepper (random black/white dots)", "speckle": "Speckle (grainy multiplicative noise)", } # --------------------------------------------------------------------------- # Model cache — load each checkpoint once # --------------------------------------------------------------------------- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") _model_cache: dict = {} def get_model(noise_type: str) -> SuperResAutoencoder: if noise_type not in _model_cache: model = SuperResAutoencoder().to(DEVICE) ckpt = os.path.join(CHECKPOINT_DIR, f"best_sr_{noise_type}.pth") load_checkpoint(ckpt, model) model.eval() _model_cache[noise_type] = model return _model_cache[noise_type] # --------------------------------------------------------------------------- # Inference helpers (copied from app/inference.py, self-contained) # --------------------------------------------------------------------------- import torch.nn.functional as F import numpy as np from torchvision import transforms def _gaussian_window(size: int, sigma: float = 0.4) -> torch.Tensor: coords = torch.arange(size, dtype=torch.float32) - size // 2 g = torch.exp(-(coords ** 2) / (2 * (sigma * size) ** 2)) g2d = g[:, None] * g[None, :] return g2d / g2d.max() def _patch_upscale(img_tensor: torch.Tensor, model: SuperResAutoencoder) -> torch.Tensor: C, H, W = img_tensor.shape patch_in = SR_PATCH_SIZE patch_out = SR_OUTPUT_SIZE overlap = SR_PATCH_OVERLAP step = patch_in - overlap scale = patch_out // patch_in pad_h = (step - (H - patch_in) % step) % step pad_w = (step - (W - patch_in) % step) % step img_padded = F.pad(img_tensor, (0, pad_w, 0, pad_h), mode="reflect") _, H_pad, W_pad = img_padded.shape output = torch.zeros(C, H_pad * scale, W_pad * scale) weight = torch.zeros(1, H_pad * scale, W_pad * scale) win = _gaussian_window(patch_out).unsqueeze(0) with torch.no_grad(): for y in range(0, H_pad - patch_in + 1, step): for x in range(0, W_pad - patch_in + 1, step): patch = img_padded[:, y:y + patch_in, x:x + patch_in] pred = model(patch.unsqueeze(0).to(DEVICE)).squeeze(0).cpu().clamp(0, 1) oy, ox = y * scale, x * scale output[:, oy:oy + patch_out, ox:ox + patch_out] += pred * win weight[:, oy:oy + patch_out, ox:ox + patch_out] += win output = output / weight.clamp(min=1e-6) return output[:, : H * scale, : W * scale] def upscale(image: Image.Image, noise_type: str) -> Image.Image: image = image.convert("RGB") model = get_model(noise_type) img_tensor = transforms.ToTensor()(image) _, H, W = img_tensor.shape with torch.no_grad(): if H <= SR_PATCH_THRESHOLD and W <= SR_PATCH_THRESHOLD: out = model(img_tensor.unsqueeze(0).to(DEVICE)).squeeze(0).cpu().clamp(0, 1) else: out = _patch_upscale(img_tensor, model) return Image.fromarray((out.permute(1, 2, 0).numpy() * 255).astype(np.uint8)) # --------------------------------------------------------------------------- # Gradio UI # --------------------------------------------------------------------------- def process(image: Image.Image, noise_label: str) -> tuple[Image.Image, str]: if image is None: return None, "Please upload an image." # Map display label back to key noise_type = next(k for k, v in NOISE_LABELS.items() if v == noise_label) W, H = image.size result = upscale(image, noise_type) W2, H2 = result.size info = f"Input: {W}x{H} -> Output: {W2}x{H2} | Noise mode: {noise_type}" return result, info with gr.Blocks(title="Image Denoiser + 2x Upscaler") as demo: gr.Markdown( """ # Image Denoiser + 2x Super-Resolution Upload an image, choose the noise type that best matches it, and get back a **denoised, 2x upscaled** version. **Noise types:** - **Gaussian** — general camera/sensor noise, film grain - **Salt & Pepper** — random black/white pixel corruption - **Speckle** — grainy multiplicative noise (common in medical/satellite images) > Model: Convolutional Autoencoder trained on STL-10 (100K images, 50 epochs). > Output is always 2x the input resolution. """ ) with gr.Row(): with gr.Column(): input_image = gr.Image(type="pil", label="Input Image") noise_selector = gr.Radio( choices=list(NOISE_LABELS.values()), value=NOISE_LABELS["gaussian"], label="Noise Type", ) run_btn = gr.Button("Denoise + Upscale", variant="primary") with gr.Column(): output_image = gr.Image(type="pil", label="Output (2x Upscaled)") info_text = gr.Textbox(label="Info", interactive=False) run_btn.click( fn=process, inputs=[input_image, noise_selector], outputs=[output_image, info_text], ) gr.Examples( examples=[], inputs=input_image, ) if __name__ == "__main__": demo.launch()