# image-denoiser / app.py
# Kajuto's picture
# Initial commit - image denoiser + SR + MLOps stack
# 8b83582
"""
Gradio app for Image Denoising + 2x Super-Resolution.
Entry point for Hugging Face Spaces.
"""
import os
import sys
import torch
import gradio as gr
from PIL import Image
from models.autoencoder import SuperResAutoencoder
from training.train import load_checkpoint
# ---------------------------------------------------------------------------
# Config (inline — no dependency on config.py paths)
# ---------------------------------------------------------------------------
# Text shown in the UI for each supported noise mode, keyed by the mode
# identifier that is also used to build checkpoint file names.
NOISE_LABELS = dict(
    gaussian="Gaussian (film grain / camera noise)",
    salt_pepper="Salt & Pepper (random black/white dots)",
    speckle="Speckle (grainy multiplicative noise)",
)
# Mode identifiers, in the same order as the labels above.
NOISE_TYPES = list(NOISE_LABELS)

# Super-resolution tiling parameters (all in pixels).
SR_PATCH_SIZE = 48  # side of each square input tile
SR_OUTPUT_SIZE = 2 * SR_PATCH_SIZE  # side of each upscaled tile (96)
SR_PATCH_OVERLAP = 8  # rows/columns shared by neighbouring input tiles
SR_PATCH_THRESHOLD = 200  # images larger than this on either side are tiled

# Directory that holds the per-noise-type model checkpoints.
CHECKPOINT_DIR = "checkpoints"
# ---------------------------------------------------------------------------
# Model cache — load each checkpoint once
# ---------------------------------------------------------------------------
# Run inference on the GPU when CUDA is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Loaded models keyed by noise type; filled lazily by get_model().
_model_cache: dict = dict()
def get_model(noise_type: str) -> SuperResAutoencoder:
    """Return the evaluation-mode model for ``noise_type``, loading it on first use.

    The checkpoint ``best_sr_<noise_type>.pth`` inside ``CHECKPOINT_DIR`` is
    loaded the first time a noise type is requested; later calls return the
    instance stored in ``_model_cache``.
    """
    try:
        return _model_cache[noise_type]
    except KeyError:
        pass  # not loaded yet; fall through and build it

    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"best_sr_{noise_type}.pth")
    network = SuperResAutoencoder().to(DEVICE)
    load_checkpoint(checkpoint_path, network)
    network.eval()
    _model_cache[noise_type] = network
    return network
# ---------------------------------------------------------------------------
# Inference helpers (copied from app/inference.py, self-contained)
# ---------------------------------------------------------------------------
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
def _gaussian_window(size: int, sigma: float = 0.4) -> torch.Tensor:
    """Build a square 2-D Gaussian weighting window whose peak value is 1.

    Args:
        size: Side length of the window in pixels.
        sigma: Standard deviation expressed as a fraction of ``size``.

    Returns:
        A ``(size, size)`` float32 tensor, symmetric about its centre and
        normalised so that its maximum is 1.
    """
    # Measure distance from the geometric centre, (size - 1) / 2.  Using
    # ``size // 2`` is identical for odd sizes but puts the peak half a pixel
    # off-centre for even sizes, which makes the window asymmetric.
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-(coords ** 2) / (2 * (sigma * size) ** 2))
    # The outer product of the 1-D profile with itself gives the 2-D window.
    g2d = g[:, None] * g[None, :]
    return g2d / g2d.max()
def _patch_upscale(img_tensor: torch.Tensor, model: SuperResAutoencoder) -> torch.Tensor:
    """Upscale an image tile by tile and blend the overlapping predictions.

    The image is padded on the bottom/right so that square tiles of
    ``SR_PATCH_SIZE`` pixels, taken every ``SR_PATCH_SIZE - SR_PATCH_OVERLAP``
    pixels, cover it exactly.  Each tile is passed through ``model`` and the
    predictions are accumulated with a Gaussian weight, then normalised by the
    summed weights.

    Args:
        img_tensor: Image tensor unpacked as ``(C, H, W)``.
        model: Network called on a batch of one tile; its output is written
            into an ``SR_OUTPUT_SIZE``-sized square of the result.

    Returns:
        A CPU tensor of shape ``(C, H * scale, W * scale)`` where
        ``scale = SR_OUTPUT_SIZE // SR_PATCH_SIZE``.
    """
    C, H, W = img_tensor.shape
    patch_in = SR_PATCH_SIZE
    patch_out = SR_OUTPUT_SIZE
    step = patch_in - SR_PATCH_OVERLAP
    scale = patch_out // patch_in

    def _pad_amount(length: int) -> int:
        """Return the padding that makes tiles of stride ``step`` fit exactly."""
        if length < patch_in:
            # A side shorter than one tile must grow to a full tile; the
            # modular formula alone can leave it too short, so that no tile
            # is ever extracted and the result would be all zeros.
            return patch_in - length
        return (step - (length - patch_in) % step) % step

    pad_h = _pad_amount(H)
    pad_w = _pad_amount(W)
    # Reflection padding requires the pad to be smaller than the dimension
    # being padded; very narrow images fall back to edge replication.
    pad_mode = "reflect" if pad_h < H and pad_w < W else "replicate"
    # Pad through a temporary batch dimension (4-D input, 2-D padding).
    img_padded = F.pad(
        img_tensor.unsqueeze(0), (0, pad_w, 0, pad_h), mode=pad_mode
    ).squeeze(0)
    _, H_pad, W_pad = img_padded.shape

    # Accumulators for the weighted predictions and for the weights themselves.
    output = torch.zeros(C, H_pad * scale, W_pad * scale)
    weight = torch.zeros(1, H_pad * scale, W_pad * scale)
    win = _gaussian_window(patch_out).unsqueeze(0)

    with torch.no_grad():
        for y in range(0, H_pad - patch_in + 1, step):
            for x in range(0, W_pad - patch_in + 1, step):
                patch = img_padded[:, y:y + patch_in, x:x + patch_in]
                pred = model(patch.unsqueeze(0).to(DEVICE)).squeeze(0).cpu().clamp(0, 1)
                oy, ox = y * scale, x * scale
                output[:, oy:oy + patch_out, ox:ox + patch_out] += pred * win
                weight[:, oy:oy + patch_out, ox:ox + patch_out] += win

    # Normalise by the accumulated weights; the clamp guards against division
    # by zero.
    output = output / weight.clamp(min=1e-6)
    # Drop the region that corresponds to the padding.
    return output[:, : H * scale, : W * scale]
def upscale(image: Image.Image, noise_type: str) -> Image.Image:
    """Run the model for ``noise_type`` on ``image`` and return the result.

    The image is converted to RGB first.  If both sides are at most
    ``SR_PATCH_THRESHOLD`` pixels the model sees the whole image in a single
    pass; otherwise it is processed tile by tile via ``_patch_upscale``.

    Args:
        image: Input PIL image (any mode accepted by ``convert("RGB")``).
        noise_type: Key selecting which checkpoint ``get_model`` loads.

    Returns:
        The model output as an 8-bit RGB PIL image.
    """
    image = image.convert("RGB")
    model = get_model(noise_type)
    img_tensor = transforms.ToTensor()(image)
    _, H, W = img_tensor.shape
    with torch.no_grad():
        if H <= SR_PATCH_THRESHOLD and W <= SR_PATCH_THRESHOLD:
            out = model(img_tensor.unsqueeze(0).to(DEVICE)).squeeze(0).cpu().clamp(0, 1)
        else:
            out = _patch_upscale(img_tensor, model)
    # Round to the nearest 8-bit level before casting.  A bare astype(uint8)
    # truncates, so e.g. 253.99998 would become 253 and the whole image would
    # be biased darker.
    pixels = out.permute(1, 2, 0).numpy() * 255
    return Image.fromarray(np.clip(np.rint(pixels), 0, 255).astype(np.uint8))
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def process(image: Image.Image, noise_label: str) -> tuple[Image.Image, str]:
    """Gradio callback: process the uploaded image and describe the result.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        noise_label: Display label from ``NOISE_LABELS`` (the internal key,
            e.g. ``"gaussian"``, is accepted as well).

    Returns:
        ``(result_image, info_text)``.  ``result_image`` is ``None`` when no
        image was supplied or the noise label is not recognised; in those
        cases ``info_text`` explains the problem.
    """
    if image is None:
        return None, "Please upload an image."

    # Map the display label back to its key.  A default is passed to next()
    # so an unknown label is reported instead of raising StopIteration.
    if noise_label in NOISE_LABELS:
        noise_type = noise_label
    else:
        noise_type = next(
            (key for key, label in NOISE_LABELS.items() if label == noise_label),
            None,
        )
    if noise_type is None:
        return None, f"Unknown noise type: {noise_label!r}"

    W, H = image.size
    result = upscale(image, noise_type)
    W2, H2 = result.size
    info = f"Input: {W}x{H} -> Output: {W2}x{H2} | Noise mode: {noise_type}"
    return result, info
# Introductory text rendered at the top of the page.
_INTRO_MARKDOWN = """
# Image Denoiser + 2x Super-Resolution
Upload an image, choose the noise type that best matches it, and get back a
**denoised, 2x upscaled** version.
**Noise types:**
- **Gaussian** — general camera/sensor noise, film grain
- **Salt & Pepper** — random black/white pixel corruption
- **Speckle** — grainy multiplicative noise (common in medical/satellite images)
> Model: Convolutional Autoencoder trained on STL-10 (100K images, 50 epochs).
> Output is always 2x the input resolution.
"""

with gr.Blocks(title="Image Denoiser + 2x Upscaler") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            label_choices = list(NOISE_LABELS.values())
            noise_selector = gr.Radio(
                label="Noise Type",
                choices=label_choices,
                value=NOISE_LABELS["gaussian"],
            )
            run_btn = gr.Button("Denoise + Upscale", variant="primary")
        # Right column: processed image and a one-line summary.
        with gr.Column():
            output_image = gr.Image(type="pil", label="Output (2x Upscaled)")
            info_text = gr.Textbox(label="Info", interactive=False)

    # Wire the button to the processing callback.
    callback_inputs = [input_image, noise_selector]
    callback_outputs = [output_image, info_text]
    run_btn.click(fn=process, inputs=callback_inputs, outputs=callback_outputs)

    # Examples section; the list of examples is currently empty.
    gr.Examples(examples=[], inputs=input_image)


if __name__ == "__main__":
    demo.launch()