# Noise2Vision / app.py
# Author: AsadAnalyst (Hugging Face Space)
# Last change: "Update app.py" (commit ac34dab, verified)
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image, ImageFilter, ImageEnhance, ImageOps
import math, os, traceback
# ─────────────────────────────────────────────────────────────────
# MODEL
# ─────────────────────────────────────────────────────────────────
class SinusoidalPositionEmbeddings(nn.Module):
    """Map a 1-D tensor of timesteps to fixed sinusoidal feature vectors.

    With ``half = dim // 2`` geometrically spaced frequencies
    ``f_k = exp(-k * ln(10000) / (half - 1))``, each timestep ``t`` becomes
    ``[sin(t * f_0), ..., sin(t * f_{half-1}), cos(t * f_0), ..., cos(t * f_{half-1})]``.
    The output therefore has shape ``(len(time), 2 * half)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=time.device))
        # outer product: one row per timestep, one column per frequency
        args = time[:, None] * freqs[None, :]
        return torch.cat([args.sin(), args.cos()], dim=-1)
class Block(nn.Module):
    """One time-conditioned U-Net stage followed by a stride-2 resampling layer.

    Layout: conv3x3 -> ReLU -> BatchNorm, add the projected time embedding,
    conv3x3 -> ReLU -> BatchNorm, then resample.

    * ``up=False`` (encoder): input has ``in_ch`` channels; the final 4x4
      stride-2 ``Conv2d`` halves H and W.
    * ``up=True`` (decoder): input is expected to already contain the skip
      connection concatenated on the channel axis, hence ``2 * in_ch`` input
      channels; the final 4x4 stride-2 ``ConvTranspose2d`` doubles H and W.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Registration order is kept stable (it fixes state_dict keys and the
        # order in which parameters draw from the RNG at init).
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        conv1_in = 2 * in_ch if up else in_ch
        resample_cls = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(conv1_in, out_ch, 3, padding=1)
        self.transform = resample_cls(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        h = self.bnorm1(self.relu(self.conv1(x)))
        # (B, out_ch) -> (B, out_ch, 1, 1) so it broadcasts over H and W
        time_bias = self.relu(self.time_mlp(t)).unsqueeze(-1).unsqueeze(-1)
        h = self.bnorm2(self.relu(self.conv2(h + time_bias)))
        return self.transform(h)
class SimpleUnet(nn.Module):
    """Small time-conditioned U-Net that predicts the noise present in an image.

    Each consecutive pair in ``down_channels`` becomes an encoder ``Block``
    (halving H and W). The reversed list produces the mirrored decoder
    ``Block``s. Before each decoder block, the current activation is
    concatenated with the matching encoder output (skip connection). A final
    1x1 conv maps to ``out_dim`` channels.

    Spatial size must be divisible by ``2 ** (len(down_channels) - 1)`` for the
    skip connections to line up.
    """

    def __init__(self, image_channels=1,
                 down_channels=(64, 128, 256, 512, 1024),
                 time_emb_dim=32, out_dim=1):
        super().__init__()
        widths_down = tuple(down_channels)
        widths_up = widths_down[::-1]
        # time_mlp.0 / .1 / .2 indices are part of the checkpoint key layout
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )
        self.conv0 = nn.Conv2d(image_channels, widths_down[0], 3, padding=1)
        self.downs = nn.ModuleList(
            Block(c_in, c_out, time_emb_dim)
            for c_in, c_out in zip(widths_down, widths_down[1:])
        )
        self.ups = nn.ModuleList(
            Block(c_in, c_out, time_emb_dim, up=True)
            for c_in, c_out in zip(widths_up, widths_up[1:])
        )
        self.output = nn.Conv2d(widths_up[-1], out_dim, 1)

    def forward(self, x, timestep):
        t = self.time_mlp(timestep)
        h = self.conv0(x)
        skips = []
        for down in self.downs:
            h = down(h, t)
            skips.append(h)
        # deepest skip first; the first one is the bottleneck output itself
        for up, skip in zip(self.ups, reversed(skips)):
            h = up(torch.cat((h, skip), dim=1), t)
        return self.output(h)
# ─────────────────────────────────────────────────────────────────
# AUTO-DETECT ARCH FROM CHECKPOINT
# ─────────────────────────────────────────────────────────────────
def detect_arch(sd):
    """Infer ``SimpleUnet`` constructor kwargs from a checkpoint state dict.

    Shapes are read from these keys:
      * ``time_mlp.1.weight``       -> (time_emb_dim, time_emb_dim)
      * ``conv0.weight``            -> (down_channels[0], image_channels, 3, 3)
      * ``downs.{i}.conv1.weight``  -> (down_channels[i + 1], ...)
      * ``output.weight``           -> (out_dim, down_channels[0], 1, 1)

    Missing keys fall back to the ``SimpleUnet`` defaults:
      * 32-dim time embedding;
      * 1 input channel and 1 output channel;
      * base width 64, doubling per level;
      * 4 down blocks when none are found.

    Returns:
        dict with ``image_channels``, ``down_channels`` (tuple),
        ``time_emb_dim`` and ``out_dim``, usable as ``SimpleUnet(**arch)``.
    """
    def _axis(key, axis, default):
        # Size of one axis of sd[key], or `default` when the key is absent.
        weight = sd.get(key)
        return int(weight.shape[axis]) if weight is not None else default

    n_down = sum(1 for k in sd if k.startswith("downs.") and k.endswith(".conv1.weight")) or 4
    down_channels = [_axis("conv0.weight", 0, 64)]
    for i in range(n_down):
        down_channels.append(_axis(f"downs.{i}.conv1.weight", 0, down_channels[-1] * 2))
    return dict(
        image_channels=_axis("conv0.weight", 1, 1),
        down_channels=tuple(down_channels),
        time_emb_dim=_axis("time_mlp.1.weight", 0, 32),
        # Previously not detected: a checkpoint with out_dim != 1 failed to load
        # with a shape mismatch even under strict=False.
        out_dim=_axis("output.weight", 0, 1),
    )
# ─────────────────────────────────────────────────────────────────
# DIFFUSION SCHEDULE (pre-compute everything once)
# ─────────────────────────────────────────────────────────────────
# Linear-beta DDPM schedule, precomputed once. Notation: abar_t = prod_{s<=t} alpha_s.
T = 300                                                 # number of diffusion timesteps
betas = torch.linspace(0.0001, 0.02, T)                 # beta_t, linear in t
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)                  # abar_t
alphas_cumprod_prev = torch.cat((torch.ones(1), alphas_cumprod[:-1]))  # abar_{t-1}, abar_{-1} := 1
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
# Posterior q(x_{t-1} | x_t, x_0) = N(coef1_t * x_0 + coef2_t * x_t, posterior_variance_t).
# The variance is exactly 0 at t = 0, so it is floored at 1e-20 to keep its log finite.
posterior_variance = torch.clamp(
    betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod), min=1e-20
)
posterior_log_variance = torch.log(posterior_variance)
posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
def _g(vals, t, x_shape, device):
    """Look up ``vals[t]`` per batch entry, shaped to broadcast against ``x_shape``.

    ``vals`` is a 1-D schedule tensor. Because ``t`` is moved to CPU for the
    gather, ``vals`` is expected to live on CPU as well. ``t`` is a 1-D long
    tensor of timesteps.

    Returns a tensor of shape ``(len(t), 1, ..., 1)`` on ``device``. It has one
    singleton axis per non-batch axis of ``x_shape``.
    """
    picked = vals.gather(-1, t.cpu()).to(device)
    singleton_axes = [1] * (len(x_shape) - 1)
    return picked.reshape(t.shape[0], *singleton_axes)
# ─────────────────────────────────────────────────────────────────
# LOAD MODEL
# ─────────────────────────────────────────────────────────────────
# Runtime configuration plus the module-level state that load_model() fills in.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
MODEL_PATH = "ddpm_model.pth"   # resolved relative to the working directory
model, load_error = None, ""    # loaded network / last load status message
def load_model():
    """Load the checkpoint at ``MODEL_PATH`` into the module-level ``model``.

    Accepted checkpoint forms:
      * a raw state dict;
      * a dict wrapping one under ``"model_state_dict"`` or ``"state_dict"``;
      * a pickled ``nn.Module``.

    Keys saved through a ``DataParallel`` wrapper (every key prefixed with
    ``"module."``) are unwrapped. The architecture is inferred with
    ``detect_arch`` and the weights are loaded non-strictly.

    Side effects:
      * sets the global ``model`` on success;
      * sets the global ``load_error`` to an error message on failure, to a
        key-mismatch summary on partial success, or to ``""`` otherwise.

    Returns:
        True when a model with at least one matching checkpoint entry was
        loaded. False when the file is missing, unreadable, or matches none of
        ``SimpleUnet``'s parameters/buffers (which would otherwise leave a
        randomly initialised network reported as "loaded").
    """
    global model, load_error
    if not os.path.exists(MODEL_PATH):
        load_error = f"ddpm_model.pth not found at: {os.path.abspath(MODEL_PATH)}"
        return False
    try:
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # point MODEL_PATH at checkpoints you trust.
        raw = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
        if isinstance(raw, dict):
            sd = raw.get("model_state_dict", raw.get("state_dict", raw))
        else:
            sd = raw.state_dict()
        # Unwrap checkpoints saved from (Distributed)DataParallel.
        if sd and all(k.startswith("module.") for k in sd):
            sd = {k[len("module."):]: v for k, v in sd.items()}
        arch = detect_arch(sd)
        m = SimpleUnet(**arch).to(DEVICE)
        missing, unexpected = m.load_state_dict(sd, strict=False)
        if len(missing) == len(m.state_dict()):
            # strict=False would otherwise accept a checkpoint that matched nothing.
            load_error = (f"Load error: no checkpoint keys match SimpleUnet "
                          f"({len(unexpected)} unexpected keys)")
            return False
        m.eval()
        model = m
        load_error = f"Loaded βœ… (missing={len(missing)}, unexpected={len(unexpected)})" if (missing or unexpected) else ""
        return True
    except Exception as e:
        # UI boundary: report in the app, but keep the full trace in the server log.
        traceback.print_exc()
        load_error = f"Load error: {e}"
        return False


# Load once at import time; the UI handlers check this flag.
model_loaded = load_model()
# ─────────────────────────────────────────────────────────────────
# SAMPLERS
# ─────────────────────────────────────────────────────────────────
@torch.no_grad()
def _predict_x0(x_t, t_val):
    """Estimate the clean image from noisy ``x_t`` at integer timestep ``t_val``.

    Runs the noise-prediction model once and inverts the forward process
    ``x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps`` for ``x0``.

    Returns:
        ``(x0_hat, eps)``: the estimate clamped to [-1, 1], and the raw
        predicted noise.
    """
    t_batch = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
    eps = model(x_t, t_batch)
    signal_scale = _g(sqrt_alphas_cumprod, t_batch, x_t.shape, DEVICE)
    noise_scale = _g(sqrt_one_minus_alphas_cumprod, t_batch, x_t.shape, DEVICE)
    x0_hat = (x_t - noise_scale * eps) / signal_scale
    return x0_hat.clamp(-1.0, 1.0), eps
@torch.no_grad()
def ddpm_step(x_t, t_val):
    """One ancestral DDPM reverse step: sample x_{t-1} ~ q(x_{t-1} | x_t, x0_hat).

    ``x0_hat`` comes from ``_predict_x0`` (clamped to [-1, 1]). The posterior
    mean is ``coef1_t * x0_hat + coef2_t * x_t``. For ``t_val > 0``, Gaussian
    noise with the posterior standard deviation is added; at ``t_val == 0``
    the mean is returned unchanged. Intended to be called for every timestep
    in turn, with no skipping.
    """
    x0_hat, _ = _predict_x0(x_t, t_val)
    t_batch = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
    mean = (_g(posterior_mean_coef1, t_batch, x_t.shape, DEVICE) * x0_hat
            + _g(posterior_mean_coef2, t_batch, x_t.shape, DEVICE) * x_t)
    if t_val > 0:
        std = (0.5 * _g(posterior_log_variance, t_batch, x_t.shape, DEVICE)).exp()
        mean = mean + std * torch.randn_like(x_t)
    return mean
@torch.no_grad()
def ddim_step(x_t, t_val, t_prev, eta=0.0):
    """Move ``x_t`` from timestep ``t_val`` to ``t_prev`` with the DDIM update.

    ``t_prev`` may be any earlier timestep, so large skips are allowed.
    ``t_prev = -1`` jumps to the clean image (``abar_{-1} = 1``).

    With ``eta == 0`` the step is deterministic. With ``eta > 0`` noise is
    injected with
    ``sigma = eta * sqrt((1 - abar_prev) / (1 - abar_t) * (1 - abar_t / abar_prev))``.
    """
    t_batch = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
    abar_t = _g(alphas_cumprod, t_batch, x_t.shape, DEVICE)
    noise_scale_t = _g(sqrt_one_minus_alphas_cumprod, t_batch, x_t.shape, DEVICE)
    has_prev = t_prev >= 0
    abar_prev = torch.tensor(
        alphas_cumprod[t_prev].item() if has_prev else 1.0,
        device=DEVICE, dtype=abar_t.dtype,
    )

    eps = model(x_t, t_batch)
    x0_hat = ((x_t - noise_scale_t * eps) / abar_t.sqrt()).clamp(-1.0, 1.0)

    if has_prev and eta > 0.0:
        var_ratio = (1.0 - abar_prev) / (1.0 - abar_t) * (1.0 - abar_t / abar_prev)
        sigma = eta * var_ratio.clamp(min=0).sqrt()
    else:
        sigma = torch.zeros(1, device=DEVICE)

    # "direction pointing to x_t": the share of eps kept at the new noise level
    direction = (1.0 - abar_prev - sigma ** 2).clamp(min=0.0).sqrt() * eps
    return abar_prev.sqrt() * x0_hat + direction + sigma * torch.randn_like(x_t)
# ─────────────────────────────────────────────────────────────────
# FORWARD NOISE
# ─────────────────────────────────────────────────────────────────
def add_noise(tensor, t_val):
    """Sample ``x_t ~ q(x_t | x_0)`` in closed form.

    ``x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps`` with ``eps ~ N(0, I)``.

    The result is intentionally NOT clamped. ``x_t`` legitimately leaves
    [-1, 1] for t > 0, and ``_predict_x0`` / the reverse samplers invert
    exactly this equation. Clamping made ``x_t - sqrt(1 - abar_t) * eps`` no
    longer equal ``sqrt(abar_t) * x_0``, and the returned ``eps`` no longer
    matched the noise actually present. Clip only when rendering to an image.

    Returns:
        ``(x_t, eps)``.
    """
    t = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
    sac = _g(sqrt_alphas_cumprod, t, tensor.shape, DEVICE)
    somac = _g(sqrt_one_minus_alphas_cumprod, t, tensor.shape, DEVICE)
    noise = torch.randn_like(tensor)
    return sac * tensor + somac * noise, noise
# ─────────────────────────────────────────────────────────────────
# IMAGE HELPERS
# ─────────────────────────────────────────────────────────────────
def preprocess(pil_img, sz, brightness, contrast, blur, sharpen, invert, equalize):
    """Turn ``pil_img`` into a ``sz`` x ``sz`` grayscale image and apply adjustments.

    Steps run in a fixed order:
      1. brightness
      2. contrast
      3. Gaussian blur
      4. unsharp mask
      5. invert
      6. histogram equalization

    Neutral settings skip their step: factor 1.0 for brightness/contrast, 0 for
    blur/sharpen, False for invert/equalize.

    Returns:
        The adjusted mode-"L" PIL image.
    """
    out = pil_img.convert("L").resize((sz, sz), Image.LANCZOS)
    for enhancer_cls, factor in ((ImageEnhance.Brightness, brightness),
                                 (ImageEnhance.Contrast, contrast)):
        if factor != 1.0:
            out = enhancer_cls(out).enhance(factor)
    if blur > 0:
        out = out.filter(ImageFilter.GaussianBlur(radius=blur))
    if sharpen > 0:
        # sharpen is scaled by 150 to become the unsharp-mask strength in percent
        strength = int(sharpen * 150)
        out = out.filter(ImageFilter.UnsharpMask(radius=2, percent=strength, threshold=3))
    if invert:
        out = ImageOps.invert(out)
    if equalize:
        out = ImageOps.equalize(out)
    return out
def to_tensor(img):
    """Convert an 8-bit image into a ``(1, 1, H, W)`` float tensor in [-1, 1] on ``DEVICE``.

    Assumes a single-channel image (e.g. PIL mode "L", as produced by
    ``preprocess``). A multi-channel input would gain an extra trailing axis.
    """
    pixels = torch.from_numpy(np.array(img, dtype=np.float32) / 255.0)
    return (pixels[None, None] * 2 - 1).to(DEVICE)
def to_pil(t):
    """Render channel 0 of batch item 0 of a [-1, 1] tensor as an RGB PIL image.

    Values are mapped to [0, 255], clipped (so out-of-range inputs saturate),
    and truncated to uint8. The grayscale result is then expanded to RGB.
    """
    plane = t[0, 0].cpu().float().numpy()
    gray = np.clip((plane + 1) / 2 * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(gray, "L").convert("RGB")
# ─────────────────────────────────────────────────────────────────
# GENERATE FROM NOISE
# ─────────────────────────────────────────────────────────────────
def _ddim_pairs(t_start, n_steps):
    """Return the ``(t, t_prev)`` pairs of a DDIM trajectory from ``t_start`` down to 0.

    Up to ``n_steps`` timesteps (at least 1, at most ``t_start + 1``) are spaced
    evenly over [0, t_start], rounded and de-duplicated, in descending order
    starting at ``t_start``. Each ``t_prev`` is the next timestep in the
    sequence. The final pair maps to ``-1`` (the clean image), so the last DDIM
    update lands on x0.
    """
    n_steps = max(1, min(int(n_steps), t_start + 1))
    grid = np.linspace(t_start, 0, n_steps).round().astype(int).tolist()
    seq = list(dict.fromkeys(grid))  # drop rounding duplicates, keep descending order
    return list(zip(seq, seq[1:] + [-1]))


def generate_image(num_steps, image_size, seed, snap_count, sampler, eta):
    """Sample a new image from pure Gaussian noise and collect progress snapshots.

    Args:
        num_steps: number of DDIM updates (capped at T). Ignored for DDPM.
            DDPM cannot skip timesteps, and pure noise only matches the last
            timestep T-1, so DDPM always runs T-1 ... 0.
        image_size: side length (px) of the square sample.
        seed: torch RNG seed, for reproducible samples.
        snap_count: approximate number of intermediate snapshots to return.
        sampler: "DDIM", or anything else for DDPM.
        eta: DDIM stochasticity (0 = deterministic).

    Returns:
        ``(final PIL image, snapshot PIL images, status)``, or
        ``(None, [], warning)`` when no model is loaded.
    """
    if not model_loaded:
        return None, [], f"⚠️ {load_error}"
    torch.manual_seed(int(seed))
    sz = int(image_size)
    x = torch.randn((1, 1, sz, sz), device=DEVICE)

    if sampler == "DDIM":
        # Previously prev-steps were built as [-1] + seq[:-1] on a descending seq,
        # which jumped straight to x0 and then stepped back *up* in noise.
        schedule = _ddim_pairs(T - 1, num_steps)
    else:
        schedule = [(t, t - 1) for t in range(T - 1, -1, -1)]
    n_steps = len(schedule)
    every = max(1, n_steps // max(1, int(snap_count)))
    eta = float(eta)

    snaps = []
    for idx, (t_cur, t_prev) in enumerate(schedule):
        if sampler == "DDIM":
            x = ddim_step(x, t_cur, t_prev, eta=eta)
        else:
            x = ddpm_step(x, t_cur)
        if idx % every == 0 or idx == n_steps - 1:
            snaps.append(to_pil(x.clamp(-1, 1)))
    return to_pil(x.clamp(-1, 1)), snaps, f"βœ… Done ({sampler}, {n_steps} steps)"
# ─────────────────────────────────────────────────────────────────
# DENOISE UPLOADED IMAGE — FIXED
# ─────────────────────────────────────────────────────────────────
def denoise_image(uploaded, noise_level, num_steps, seed, sampler, eta,
                  img_size, brightness, contrast, blur, sharpen, invert, equalize):
    """Run the four-stage denoising demo on an uploaded image.

    Stages:
      1. Preprocess the upload into an ``img_size`` square grayscale image.
      2. Diffuse it forward to ``t = int(noise_level * (T - 1))``, clamped to
         [1, T - 1].
      3. Predict x0 from the noisy image with a single model call.
      4. Run the reverse process from t back to 0. DDIM uses roughly
         ``num_steps`` strided updates (at least 1). DDPM visits every
         timestep and ignores ``num_steps``.

    Returns:
        ``(preprocessed, noisy, one-step x0 estimate, reconstruction, status)``.
        The first four are PIL images, or ``None`` on early exit. The status
        reports the number of DDIM updates actually executed.
    """
    if not model_loaded:
        return None, None, None, None, f"⚠️ {load_error}"
    if uploaded is None:
        return None, None, None, None, "⚠️ Please upload an image."
    torch.manual_seed(int(seed))
    sz = int(img_size)
    prc = preprocess(Image.fromarray(uploaded), sz,
                     brightness, contrast, blur, sharpen, invert, equalize)
    pre_rgb = prc.convert("RGB")
    x0 = to_tensor(prc)  # clean image in [-1, 1]

    # forward process: jump straight to timestep t_val
    t_val = max(1, min(T - 1, int(float(noise_level) * (T - 1))))
    x_noisy, _ = add_noise(x0, t_val)
    noisy_pil = to_pil(x_noisy)

    # reverse process
    x = x_noisy.clone()
    if sampler == "DDIM":
        # Strided timesteps in [0, t_val], always including t_val itself.
        n = max(1, min(int(num_steps), t_val + 1))  # guard 0/negative -> no ZeroDivisionError
        seq = list(range(0, t_val + 1, max(1, (t_val + 1) // n)))
        if seq[-1] != t_val:
            seq.append(t_val)
        seq.reverse()                  # high -> low
        seq_prev = seq[1:] + [-1]      # -1 means "the clean image"
        for t_cur, t_prev in zip(seq, seq_prev):
            x = ddim_step(x, t_cur, t_prev, eta=float(eta))
        steps_label = len(seq)         # actual updates run, not the requested count
    else:
        # DDPM cannot skip: every timestep from t_val down to 0.
        for t_i in range(t_val, -1, -1):
            x = ddpm_step(x, t_i)
        steps_label = "all"
    recon_pil = to_pil(x.clamp(-1, 1))

    # one-shot x0 estimate straight from the noisy input, for comparison
    x0_direct, _ = _predict_x0(x_noisy, t_val)
    direct_pil = to_pil(x0_direct)
    return (pre_rgb, noisy_pil, direct_pil, recon_pil,
            f"βœ… Done ({sampler}, t={t_val}, steps={steps_label})")
def show_artifact(which):
    """Open the saved training artifact chosen in the UI radio.

    ``which`` is one of the radio labels; unknown labels map to an empty path.

    Returns:
        ``(PIL image, status)`` when the mapped file exists relative to the
        working directory, otherwise ``(None, "not found" status)``.
    """
    artifacts = {
        "πŸ“‰ Loss Curve": "loss_plot.png",
        "πŸ–ΌοΈ Reconstruction": "reconstruction.png",
        "πŸ”„ Reverse Steps": "reverse_steps.png",
    }
    path = artifacts.get(which, "")
    if not path or not os.path.exists(path):
        return None, f"⚠️ Not found: {path}"
    return Image.open(path), f"βœ… {path}"
# ─────────────────────────────────────────────────────────────────
# CSS
# ─────────────────────────────────────────────────────────────────
# Global stylesheet for the Gradio app (passed to demo.launch(css=CSS) at the bottom of the file).
# Contents, in order:
#   - dark palette as CSS custom properties (--bg, --acc, --mut, ...);
#   - hero banner (#hero, .htitle, .hsub, .badge.ok/.err);
#   - tab, input and primary-button theming;
#   - helper classes .sl (section label) and .tip (hint box);
#   - rounded image borders and the #footer.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=Syne:wght@400;600;700;800&display=swap');
:root {
--bg:#07071a; --card:#0c0c22; --inp:#10102a;
--acc:#a78bfa; --grn:#34d399; --cyn:#38bdf8; --glow:#7c3aed;
--txt:#e2e8f0; --mut:#64748b; --bdr:rgba(167,139,250,.14); --r:12px;
}
body,.gradio-container{background:var(--bg)!important;font-family:'Syne',sans-serif!important;color:var(--txt)!important;}
/* hero */
#hero{background:linear-gradient(160deg,#0b0629,#120b36,#071528);border:1px solid var(--bdr);border-radius:var(--r);padding:2.6rem 2rem 2rem;text-align:center;position:relative;overflow:hidden;margin-bottom:1.4rem;}
#hero::before{content:'';position:absolute;inset:0;background:radial-gradient(ellipse 70% 55% at 50% 0%,rgba(124,58,237,.3),transparent 68%);}
#hero::after{content:'';position:absolute;inset:0;background:url("data:image/svg+xml,%3Csvg width='60' height='60' viewBox='0 0 60 60' xmlns='http://www.w3.org/2000/svg'%3E%3Cg fill='%239C92AC' fill-opacity='0.03'%3E%3Cpath d='M36 34v-4h-2v4h-4v2h4v4h2v-4h4v-2h-4zm0-30V0h-2v4h-4v2h4v4h2V6h4V4h-4zM6 34v-4H4v4H0v2h4v4h2v-4h4v-2H6zM6 4V0H4v4H0v2h4v4h2V6h4V4H6z'/%3E%3C/g%3E%3C/svg%3E");}
.htitle{font-size:2.9rem;font-weight:800;letter-spacing:-1.5px;background:linear-gradient(100deg,var(--acc),var(--cyn) 55%,var(--grn));-webkit-background-clip:text;-webkit-text-fill-color:transparent;margin:0 0 .3rem;position:relative;}
.hsub{font-family:'Space Mono',monospace;font-size:.87rem;color:var(--mut);margin:0;position:relative;}
.badge{display:inline-flex;align-items:center;gap:.4rem;margin-top:.8rem;padding:.28rem 1rem;border-radius:999px;font-family:'Space Mono',monospace;font-size:.74rem;font-weight:700;border:1px solid rgba(167,139,250,.3);background:rgba(167,139,250,.08);color:var(--acc);position:relative;}
.badge.ok{border-color:rgba(52,211,153,.35);background:rgba(52,211,153,.08);color:var(--grn);}
.badge.err{border-color:rgba(248,113,113,.35);background:rgba(248,113,113,.08);color:#f87171;}
/* tabs */
.tab-nav{border-bottom:1px solid var(--bdr)!important;}
.tab-nav button{font-family:'Syne',sans-serif!important;font-weight:600!important;font-size:.88rem!important;color:var(--mut)!important;background:transparent!important;border:none!important;border-bottom:2px solid transparent!important;padding:.7rem 1.4rem!important;transition:all .2s!important;}
.tab-nav button.selected{color:var(--acc)!important;border-bottom-color:var(--acc)!important;}
/* inputs */
input[type=number],input[type=text],textarea{background:var(--inp)!important;border:1px solid var(--bdr)!important;color:var(--txt)!important;border-radius:8px!important;font-family:'Space Mono',monospace!important;font-size:.82rem!important;}
input[type=range]{accent-color:var(--acc)!important;}
label span,.label-wrap span{font-family:'Space Mono',monospace!important;font-size:.72rem!important;color:var(--mut)!important;text-transform:uppercase;letter-spacing:.07em;}
.gr-check-radio{accent-color:var(--acc)!important;}
/* buttons */
button.primary,.gr-button-primary{background:linear-gradient(135deg,var(--glow),#0e7490)!important;color:#fff!important;border:none!important;border-radius:9px!important;font-family:'Syne',sans-serif!important;font-weight:700!important;font-size:.94rem!important;padding:.65rem 2rem!important;box-shadow:0 0 20px rgba(124,58,237,.4)!important;transition:all .2s!important;width:100%;}
button.primary:hover{box-shadow:0 0 36px rgba(124,58,237,.65)!important;transform:translateY(-2px)!important;}
/* section label */
.sl{font-family:'Space Mono',monospace;font-size:.7rem;font-weight:700;letter-spacing:.1em;text-transform:uppercase;color:var(--mut);padding-bottom:.4rem;border-bottom:1px solid var(--bdr);margin-bottom:.7rem;}
/* tip box */
.tip{background:rgba(56,189,248,.06);border:1px solid rgba(56,189,248,.18);border-radius:8px;padding:.7rem 1rem;font-family:'Space Mono',monospace;font-size:.78rem;color:var(--cyn);margin-top:.5rem;}
/* images */
.gr-image img,.output-image img{border-radius:10px!important;border:1px solid var(--bdr)!important;}
.gallery-item{border-radius:8px!important;}
.gr-textbox textarea{font-family:'Space Mono',monospace!important;font-size:.8rem!important;}
#footer{text-align:center;padding:1.2rem;font-family:'Space Mono',monospace;font-size:.74rem;color:var(--mut);margin-top:1rem;border-top:1px solid var(--bdr);}
"""
# Hero badge: CSS modifier class (bc), icon (bi) and message (bm), picked by load status.
if model_loaded:
    bc, bi, bm = "ok", "🟒", "Model Loaded Β· DDPM Ready"
else:
    bc, bi, bm = "err", "πŸ”΄", f"Model Not Found β€” place ddpm_model.pth next to app.py"
# ─────────────────────────────────────────────────────────────────
# UI
# ─────────────────────────────────────────────────────────────────
# Gradio interface: hero banner, then Generate / Denoise / Artifacts / About tabs, then a footer.
# Component creation order defines the on-screen layout, so statements here are order-sensitive.
with gr.Blocks(title="Noise2Vision β€” DDPM") as demo:
    # Hero banner; bc/bi/bm select the badge style, icon and text from the load status.
    gr.HTML(f"""
<div id="hero">
<div class="htitle">⚑ Noise2Vision</div>
<p class="hsub">Denoising Diffusion Probabilistic Model &nbsp;Β·&nbsp; Reverse the noise, reveal the signal</p>
<div class="badge {bc}">{bi} {bm}</div>
</div>
""")
    with gr.Tabs():
        # ══ GENERATE ══════════════════════════════════════════════
        with gr.Tab("🎲 Generate"):
            gr.Markdown("#### Generate a new image from pure Gaussian noise via reverse diffusion.")
            with gr.Row(equal_height=False):
                # left: sampler controls
                with gr.Column(scale=1, min_width=270):
                    gr.HTML('<div class="sl">βš™οΈ Diffusion Controls</div>')
                    g_sampler = gr.Radio(["DDPM","DDIM"], value="DDIM", label="Sampler")
                    g_steps = gr.Slider(10, 300, value=100, step=10, label="Steps")
                    g_size = gr.Slider(32, 128, value=64, step=32, label="Output Size (px)")
                    g_seed = gr.Number(value=42, label="Random Seed", precision=0)
                    g_snap = gr.Slider(4, 16, value=8, step=2, label="Snapshots")
                    g_eta = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="DDIM Ξ· (0=deterministic, 1=DDPM-like)")
                    gen_btn = gr.Button("✦ Generate", variant="primary")
                    gen_status = gr.Textbox(label="Status", interactive=False, lines=2)
                # right: final sample and intermediate snapshots
                with gr.Column(scale=2):
                    gr.HTML('<div class="sl">πŸ–ΌοΈ Result</div>')
                    gen_out = gr.Image(label="Generated", type="pil", height=300)
                    gr.HTML('<div class="sl" style="margin-top:.8rem">🎞️ Snapshots</div>')
                    gen_gallery = gr.Gallery(label="", columns=8, height=120, allow_preview=True)
            # inputs are passed positionally, matching generate_image's parameter order
            gen_btn.click(generate_image,
                          [g_steps, g_size, g_seed, g_snap, g_sampler, g_eta],
                          [gen_out, gen_gallery, gen_status])
        # ══ DENOISE ════════════════════════════════════════════════
        with gr.Tab("πŸ”¬ Denoise Upload"):
            gr.Markdown("#### Upload β†’ preprocess β†’ add noise β†’ reconstruct. Four-stage pipeline.")
            with gr.Row(equal_height=False):
                # left: upload, preprocessing and diffusion settings
                with gr.Column(scale=1, min_width=295):
                    gr.HTML('<div class="sl">πŸ“‚ Image</div>')
                    up_img = gr.Image(label="Upload Image", type="numpy", height=175)
                    gr.HTML('<div class="sl" style="margin-top:.9rem">🎨 Preprocessing</div>')
                    img_size = gr.Slider(32, 128, value=64, step=32, label="Resize (px)")
                    with gr.Row():
                        brightness = gr.Slider(0.3, 2.5, value=1.0, step=0.1, label="Brightness")
                        contrast = gr.Slider(0.3, 2.5, value=1.0, step=0.1, label="Contrast")
                    with gr.Row():
                        blur_s = gr.Slider(0.0, 5.0, value=0.0, step=0.5, label="Blur")
                        sharpen = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Sharpen")
                    with gr.Row():
                        invert_c = gr.Checkbox(label="πŸ”„ Invert", value=False)
                        equalize_c = gr.Checkbox(label="πŸ“Š Equalize", value=False)
                    gr.HTML('<div class="sl" style="margin-top:.9rem">βš™οΈ Diffusion</div>')
                    d_sampler = gr.Radio(["DDPM","DDIM"], value="DDIM", label="Sampler")
                    noise_lvl = gr.Slider(0.05, 0.99, value=0.5, step=0.01, label="Noise Level (t/T)")
                    den_steps = gr.Slider(10, 200, value=50, step=5, label="DDIM Steps (ignored for DDPM)")
                    den_eta = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="DDIM Ξ·")
                    den_seed = gr.Number(value=0, label="Seed", precision=0)
                    gr.HTML('<div class="tip">πŸ’‘ <b>Tip:</b> DDIM + 50 steps gives sharp, fast reconstruction.<br>DDPM runs all tβ†’0 steps (slow but theoretically exact).</div>')
                    den_btn = gr.Button("✦ Reconstruct", variant="primary")
                    den_status = gr.Textbox(label="Status", interactive=False, lines=2)
                # right: the four pipeline outputs in a 2x2 grid
                with gr.Column(scale=2):
                    gr.HTML('<div class="sl">πŸ“Š 4-Stage Pipeline</div>')
                    with gr.Row():
                        pre_out = gr.Image(label="β‘  Preprocessed", type="pil", height=200)
                        noisy_out = gr.Image(label="β‘‘ Noisy (t)", type="pil", height=200)
                    with gr.Row():
                        direct_out = gr.Image(label="β‘’ Direct xβ‚€ Estimate (1-step)", type="pil", height=200)
                        recon_out = gr.Image(label="β‘£ Full Reconstruction", type="pil", height=200)
                    gr.HTML('<div class="tip">β‘’ is a fast single-step prediction. β‘£ is the full iterative reverse result.</div>')
            # input order must match denoise_image's positional parameters
            den_btn.click(
                denoise_image,
                [up_img, noise_lvl, den_steps, den_seed, d_sampler, den_eta,
                 img_size, brightness, contrast, blur_s, sharpen, invert_c, equalize_c],
                [pre_out, noisy_out, direct_out, recon_out, den_status]
            )
        # ══ ARTIFACTS ══════════════════════════════════════════════
        with gr.Tab("πŸ“Š Training Artifacts"):
            gr.Markdown("#### View saved training outputs.")
            with gr.Row():
                # radio labels must match the keys in show_artifact's mapping
                art_radio = gr.Radio(
                    ["πŸ“‰ Loss Curve","πŸ–ΌοΈ Reconstruction","πŸ”„ Reverse Steps"],
                    value="πŸ“‰ Loss Curve", label="Select"
                )
                view_btn = gr.Button("View β†’", variant="primary", scale=0)
            art_status = gr.Textbox(label="", interactive=False, lines=1)
            art_out = gr.Image(label="", type="pil", height=460)
            view_btn.click(show_artifact, [art_radio], [art_out, art_status])
        # ══ ABOUT ══════════════════════════════════════════════════
        with gr.Tab("ℹ️ About"):
            # f-string: {{...}} are literal braces; the model-path line is evaluated at build time.
            gr.Markdown(f"""
## Noise2Vision β€” DDPM
### What changed in reconstruction
| Old (broken) | New (fixed) |
|---|---|
| Skipped timesteps with `stride` | DDPM runs **every** step t→0 |
| Wrong posterior: used `betas * pred / sqrt_omac` | Correct `q(x_{{t-1}}|x_t,xΜ‚_0)` posterior mean |
| Clamped intermediate latents | Only clamp final output |
| No DDIM | **DDIM** added β€” 50 steps β‰ˆ quality of 300 DDPM steps |
| Single output | **4-stage** output: preprocessed β†’ noisy β†’ direct xΜ‚β‚€ β†’ iterative recon |
### Architecture (auto-detected)
| Component | Detail |
|---|---|
| Backbone | U-Net + sinusoidal time embeddings |
| Encoder | 64β†’128β†’256β†’512β†’1024 |
| Decoder | 1024β†’512β†’256β†’128β†’64 |
| T | 300 Β· Linear Ξ² schedule 0.0001β†’0.02 |
### Files
`ddpm_model.pth` Β· `state.db` Β· `loss_plot.png` Β· `reconstruction.png` Β· `reverse_steps.png`
Model path: `{os.path.abspath(MODEL_PATH)}` β€” {"βœ… Found" if os.path.exists(MODEL_PATH) else "❌ Not found"}
---
*Noise2Vision Β· AsadAnalyst Β· Hugging Face Spaces*
""")
    gr.HTML('<div id="footer">Noise2Vision &nbsp;Β·&nbsp; DDPM + DDIM &nbsp;Β·&nbsp; Gradio &nbsp;Β·&nbsp; AsadAnalyst</div>')

# Script entry point: CSS and theme are passed at launch time here rather than to gr.Blocks().
if __name__ == "__main__":
    demo.launch(css=CSS, theme=gr.themes.Base())