Spaces:
Sleeping
Sleeping
| """ | |
| Gradio app for Image Denoising + 2x Super-Resolution. | |
| Entry point for Hugging Face Spaces. | |
| """ | |
| import os | |
| import sys | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from models.autoencoder import SuperResAutoencoder | |
| from training.train import load_checkpoint | |
| # --------------------------------------------------------------------------- | |
| # Config (inline — no dependency on config.py paths) | |
| # --------------------------------------------------------------------------- | |
# Noise-model variant key -> human-readable label shown in the UI radio.
NOISE_LABELS = {
    "gaussian": "Gaussian (film grain / camera noise)",
    "salt_pepper": "Salt & Pepper (random black/white dots)",
    "speckle": "Speckle (grainy multiplicative noise)",
}
# The variant keys, in the same order as NOISE_LABELS (derived so they cannot drift apart).
NOISE_TYPES = list(NOISE_LABELS)

# Images whose height and width are both <= this go through the model in one
# pass; anything larger is processed as overlapping tiles.
SR_PATCH_THRESHOLD = 200
# Tile edge length fed to the model, in input pixels.
SR_PATCH_SIZE = 48
# Tile edge length the model produces, in output pixels.
SR_OUTPUT_SIZE = 96
# Overlap between neighbouring tiles, in input pixels.
SR_PATCH_OVERLAP = 8
# Directory holding the per-noise-type checkpoints (best_sr_<noise_type>.pth).
CHECKPOINT_DIR = "checkpoints"
| # --------------------------------------------------------------------------- | |
| # Model cache — load each checkpoint once | |
| # --------------------------------------------------------------------------- | |
# Run inference on CUDA when torch reports it available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# noise_type -> loaded model; filled lazily by get_model so each checkpoint is read once.
_model_cache: dict = {}
def get_model(noise_type: str) -> SuperResAutoencoder:
    """Return the eval-mode model for *noise_type*, loading it on first use.

    The first request for a noise type builds a fresh ``SuperResAutoencoder`` on
    ``DEVICE``, restores ``<CHECKPOINT_DIR>/best_sr_<noise_type>.pth`` into it via
    ``load_checkpoint``, switches it to eval mode and stores it in
    ``_model_cache``; later requests return that same instance.
    """
    cached = _model_cache.get(noise_type)
    if cached is not None:
        return cached
    net = SuperResAutoencoder().to(DEVICE)
    weights_path = os.path.join(CHECKPOINT_DIR, f"best_sr_{noise_type}.pth")
    load_checkpoint(weights_path, net)
    net.eval()
    _model_cache[noise_type] = net
    return net
| # --------------------------------------------------------------------------- | |
| # Inference helpers (copied from app/inference.py, self-contained) | |
| # --------------------------------------------------------------------------- | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from torchvision import transforms | |
def _gaussian_window(size: int, sigma: float = 0.4) -> torch.Tensor:
    """Return a ``(size, size)`` float32 Gaussian weight map whose peak is 1.0.

    ``sigma`` is relative to the window: the standard deviation is
    ``sigma * size`` pixels. The peak sits at index ``size // 2`` on both axes,
    and the 2-D map is the outer product of one 1-D profile with itself.
    """
    offsets = torch.arange(size, dtype=torch.float32) - size // 2
    spread = sigma * size
    profile = torch.exp(-offsets.pow(2) / (2 * spread ** 2))
    window = profile.unsqueeze(1) * profile.unsqueeze(0)
    # Normalise so the largest weight is exactly 1.
    return window / window.max()
def _patch_upscale(img_tensor: torch.Tensor, model: SuperResAutoencoder) -> torch.Tensor:
    """Upscale an image tile-by-tile, blending overlapping tiles with a Gaussian window.

    The image is padded on the bottom/right, cut into ``SR_PATCH_SIZE`` square
    tiles spaced ``SR_PATCH_SIZE - SR_PATCH_OVERLAP`` pixels apart, and each tile
    is run through *model*. Predictions are accumulated into an output canvas
    weighted by ``_gaussian_window(SR_OUTPUT_SIZE)`` and divided by the summed
    weights, then the upscaled padding is cropped away.

    Args:
        img_tensor: ``(C, H, W)`` float image (from ``ToTensor`` in ``upscale``,
            so presumably in ``[0, 1]``).
        model: Network assumed to map a ``(1, C, SR_PATCH_SIZE, SR_PATCH_SIZE)``
            batch to ``(1, C, SR_OUTPUT_SIZE, SR_OUTPUT_SIZE)`` -- inferred from
            how its output is written into the canvas; confirm against the model.

    Returns:
        ``(C, H * scale, W * scale)`` CPU tensor, ``scale = SR_OUTPUT_SIZE // SR_PATCH_SIZE``,
        holding weighted averages of predictions clamped to ``[0, 1]``.
    """
    C, H, W = img_tensor.shape
    patch_in = SR_PATCH_SIZE
    patch_out = SR_OUTPUT_SIZE
    step = patch_in - SR_PATCH_OVERLAP
    scale = patch_out // patch_in

    def _grid_pad(size: int) -> int:
        """Return the trailing padding that makes *size* fit the tile grid exactly."""
        if size < patch_in:
            # Pad short sides up to exactly one tile. The alignment formula below
            # leaves sides <= SR_PATCH_OVERLAP shorter than a tile, in which case
            # no tile would be processed and the output would be all zeros.
            return patch_in - size
        # Pad so (size + pad - patch_in) is a multiple of step, i.e. the last
        # tile ends exactly on the padded border.
        return (step - (size - patch_in) % step) % step

    pad_h = _grid_pad(H)
    pad_w = _grid_pad(W)
    # Reflect padding requires each pad to be smaller than the dimension it pads;
    # thin strips (e.g. 20 px wide) need more than that, so replicate edges instead
    # of raising.
    pad_mode = "reflect" if pad_h < H and pad_w < W else "replicate"
    # Pad as a one-image batch: 4-D input is the documented case for 2-D
    # reflect/replicate padding.
    img_padded = F.pad(img_tensor.unsqueeze(0), (0, pad_w, 0, pad_h), mode=pad_mode).squeeze(0)
    _, H_pad, W_pad = img_padded.shape

    output = torch.zeros(C, H_pad * scale, W_pad * scale)
    weight = torch.zeros(1, H_pad * scale, W_pad * scale)
    # Leading singleton dim so the window broadcasts over channels.
    win = _gaussian_window(patch_out).unsqueeze(0)

    with torch.no_grad():
        for y in range(0, H_pad - patch_in + 1, step):
            for x in range(0, W_pad - patch_in + 1, step):
                patch = img_padded[:, y:y + patch_in, x:x + patch_in]
                pred = model(patch.unsqueeze(0).to(DEVICE)).squeeze(0).cpu().clamp(0, 1)
                oy, ox = y * scale, x * scale
                output[:, oy:oy + patch_out, ox:ox + patch_out] += pred * win
                weight[:, oy:oy + patch_out, ox:ox + patch_out] += win

    # Normalise by the accumulated weights; the clamp only guards against division by zero.
    output = output / weight.clamp(min=1e-6)
    # Crop off the (upscaled) padding.
    return output[:, : H * scale, : W * scale]
def upscale(image: Image.Image, noise_type: str) -> Image.Image:
    """Denoise and upscale *image* with the model for *noise_type*.

    The image is converted to RGB and turned into a tensor. If both sides are
    <= ``SR_PATCH_THRESHOLD`` it goes through the model in a single pass;
    otherwise it is processed tile-by-tile by ``_patch_upscale``.

    Args:
        image: Input PIL image in any mode (converted to RGB first).
        noise_type: Noise-type key (e.g. ``"gaussian"``) selecting the checkpoint
            loaded by ``get_model``.

    Returns:
        8-bit PIL image built from the model output (RGB assuming the model
        returns 3 channels).
    """
    image = image.convert("RGB")
    model = get_model(noise_type)
    img_tensor = transforms.ToTensor()(image)
    _, H, W = img_tensor.shape
    with torch.no_grad():
        if H <= SR_PATCH_THRESHOLD and W <= SR_PATCH_THRESHOLD:
            out = model(img_tensor.unsqueeze(0).to(DEVICE)).squeeze(0).cpu().clamp(0, 1)
        else:
            out = _patch_upscale(img_tensor, model)
    # Round to the nearest 8-bit level: a bare astype(uint8) truncates toward
    # zero, biasing every pixel downward (e.g. 0.999 -> 254). The clip guards
    # against tiny float overshoot from the tile blending.
    pixels = np.clip(np.rint(out.permute(1, 2, 0).numpy() * 255), 0, 255)
    return Image.fromarray(pixels.astype(np.uint8))
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
def process(image: Image.Image, noise_label: str) -> tuple[Image.Image, str]:
    """Gradio click handler: denoise and upscale the uploaded image.

    Args:
        image: Uploaded image, or ``None`` when nothing has been uploaded.
        noise_label: Display label chosen in the radio (a value of ``NOISE_LABELS``).

    Returns:
        ``(result_image, info_message)``. ``result_image`` is ``None`` when the
        request cannot be processed, and the message then says why.
    """
    if image is None:
        return None, "Please upload an image."
    # Map the display label back to its key. A default is supplied so an unknown
    # or missing label (e.g. from a direct API call) produces a user-facing
    # message instead of StopIteration escaping the handler; StopIteration is
    # special-cased by generator/async machinery (asyncio Futures reject it).
    noise_type = next(
        (key for key, label in NOISE_LABELS.items() if label == noise_label), None
    )
    if noise_type is None:
        return None, "Please select a noise type."
    W, H = image.size
    result = upscale(image, noise_type)
    W2, H2 = result.size
    info = f"Input: {W}x{H} -> Output: {W2}x{H2} | Noise mode: {noise_type}"
    return result, info
# Gradio UI: a Markdown header, then a row with an input column (image upload,
# noise-type radio, run button) and an output column (result image, info text).
with gr.Blocks(title="Image Denoiser + 2x Upscaler") as demo:
    gr.Markdown(
"""
# Image Denoiser + 2x Super-Resolution
Upload an image, choose the noise type that best matches it, and get back a
**denoised, 2x upscaled** version.
**Noise types:**
- **Gaussian** — general camera/sensor noise, film grain
- **Salt & Pepper** — random black/white pixel corruption
- **Speckle** — grainy multiplicative noise (common in medical/satellite images)
> Model: Convolutional Autoencoder trained on STL-10 (100K images, 50 epochs).
> Output is always 2x the input resolution.
"""
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            # The radio shows the human-readable labels; process() maps the
            # selected label back to its NOISE_LABELS key.
            noise_selector = gr.Radio(
                choices=list(NOISE_LABELS.values()),
                value=NOISE_LABELS["gaussian"],
                label="Noise Type",
            )
            run_btn = gr.Button("Denoise + Upscale", variant="primary")
        # Right column: outputs.
        with gr.Column():
            output_image = gr.Image(type="pil", label="Output (2x Upscaled)")
            info_text = gr.Textbox(label="Info", interactive=False)
    # Clicking the button calls process(image, label) -> (image, info text).
    run_btn.click(
        fn=process,
        inputs=[input_image, noise_selector],
        outputs=[output_image, info_text],
    )
    # NOTE(review): the examples list is empty, so no example inputs are offered;
    # this looks like a placeholder -- populate it or remove the component.
    gr.Examples(
        examples=[],
        inputs=input_image,
    )

# Start the Gradio server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()