import math
import os
import traceback

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageFilter, ImageEnhance, ImageOps

# ─────────────────────────────────────────────────────────────────
# MODEL
# ─────────────────────────────────────────────────────────────────


class SinusoidalPositionEmbeddings(nn.Module):
    """Turn a 1-D tensor of timesteps into sinusoidal embeddings.

    The output has ``2 * (dim // 2)`` columns: the sines followed by the
    cosines of the timestep scaled by geometrically spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        # max(..., 1) keeps dim == 2 or 3 (half_dim == 1) from dividing by
        # zero; for every larger dim the value is unchanged.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -scale)
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Block(nn.Module):
    """Two 3x3 conv stages with time conditioning, then a stride-2 resample.

    With ``up=False`` the final layer is a stride-2 ``Conv2d``; with
    ``up=True`` it is a stride-2 ``ConvTranspose2d`` and the first
    convolution expects ``2 * in_ch`` input channels (the caller
    concatenates a skip tensor on the channel axis).
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Project the time embedding to out_ch and append two singleton
        # axes so it broadcasts over the spatial dimensions of h.
        t_emb = self.relu(self.time_mlp(t))[..., None, None]
        h = self.bnorm2(self.relu(self.conv2(h + t_emb)))
        return self.transform(h)


class SimpleUnet(nn.Module):
    """U-Net style noise predictor built from ``Block`` stages.

    ``down_channels`` lists the channel width after ``conv0`` and after each
    down block; the up path mirrors it in reverse.
    """

    def __init__(self, image_channels=1, down_channels=(64, 128, 256, 512, 1024),
                 time_emb_dim=32, out_dim=1):
        super().__init__()
        up_channels = tuple(reversed(down_channels))
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
        self.downs = nn.ModuleList([
            Block(down_channels[i], down_channels[i + 1], time_emb_dim)
            for i in range(len(down_channels) - 1)
        ])
        self.ups = nn.ModuleList([
            Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True)
            for i in range(len(up_channels) - 1)
        ])
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        t = self.time_mlp(timestep)
        x = self.conv0(x)
        skips = []
        for down in self.downs:
            x = down(x, t)
            skips.append(x)
        for up in self.ups:
            # Skips are consumed last-in, first-out and joined on channels.
            x = torch.cat((x, skips.pop()), dim=1)
            x = up(x, t)
        return self.output(x)


# ─────────────────────────────────────────────────────────────────
# AUTO-DETECT ARCH FROM CHECKPOINT
# ─────────────────────────────────────────────────────────────────
def detect_arch(sd):
    """Infer ``SimpleUnet`` constructor arguments from a state dict.

    Reads the shapes of ``time_mlp.1.weight``, ``conv0.weight``,
    ``downs.<i>.conv1.weight`` and ``output.weight``. Missing keys fall back
    to the ``SimpleUnet`` defaults (32-dim time embedding, 1 input channel,
    64 base channels, 4 down blocks that each double the width, 1 output
    channel).

    Returns:
        dict with ``image_channels``, ``down_channels``, ``time_emb_dim`` and
        ``out_dim``, suitable for ``SimpleUnet(**arch)``.
    """
    time_emb_dim = sd.get("time_mlp.1.weight", torch.zeros(32, 32)).shape[0]
    image_channels = sd.get("conv0.weight", torch.zeros(1, 1, 1, 1)).shape[1]

    n_down = sum(
        1 for k in sd if k.startswith("downs.") and k.endswith(".conv1.weight")
    )
    n_down = n_down or 4

    down_channels = [sd.get("conv0.weight", torch.zeros(64, 1, 1, 1)).shape[0]]
    for i in range(n_down):
        fallback = torch.zeros(down_channels[-1] * 2, 1, 1, 1)
        down_channels.append(sd.get(f"downs.{i}.conv1.weight", fallback).shape[0])

    # Without this, a checkpoint whose output conv is not 1-channel cannot be
    # loaded: load_state_dict raises on size mismatch even with strict=False.
    out_dim = sd["output.weight"].shape[0] if "output.weight" in sd else 1

    return dict(image_channels=image_channels,
                down_channels=tuple(down_channels),
                time_emb_dim=time_emb_dim,
                out_dim=out_dim)


# ─────────────────────────────────────────────────────────────────
# DIFFUSION SCHEDULE (pre-compute everything once)
# ─────────────────────────────────────────────────────────────────
T = 300
betas = torch.linspace(0.0001, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
sqrt_recip_alphas = (1.0 / alphas).sqrt()

# posterior variance q(x_{t-1}|x_t, x_0)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# The t = 0 entry is exactly 0; clamp so the log below stays finite.
posterior_variance = posterior_variance.clamp(min=1e-20)
posterior_log_variance = posterior_variance.log()
posterior_mean_coef1 = betas * alphas_cumprod_prev.sqrt() / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * alphas.sqrt() / (1.0 - alphas_cumprod)


def _g(vals, t, x_shape, device):
    """Gather scalar schedule value for batch index t, broadcast to x_shape."""
    # The schedule tensors are created on the CPU, so index there and then
    # move the gathered values to the target device.
    out = vals.gather(-1, t.cpu()).to(device)
    return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))


# ─────────────────────────────────────────────────────────────────
# LOAD MODEL
# ─────────────────────────────────────────────────────────────────
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH = "ddpm_model.pth"
model = None
load_error = ""


def load_model():
    """Load ``MODEL_PATH`` into the module-level ``model``.

    Accepts a raw state dict, a dict holding ``model_state_dict``, or a
    pickled module exposing ``state_dict()``. The architecture is inferred
    with ``detect_arch``. On any failure the reason is stored in the
    module-level ``load_error`` and the traceback is printed.

    Returns:
        True when the model was built and put in eval mode, False otherwise.
    """
    global model, load_error
    if not os.path.exists(MODEL_PATH):
        load_error = f"ddpm_model.pth not found at: {os.path.abspath(MODEL_PATH)}"
        return False
    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects and
        # can execute code from the file. It is kept because whole-module
        # checkpoints are supported below; only load trusted checkpoints.
        raw = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
        if isinstance(raw, dict):
            sd = raw.get("model_state_dict", raw)
        else:
            sd = raw.state_dict()
        arch = detect_arch(sd)
        m = SimpleUnet(**arch).to(DEVICE)
        # strict=False tolerates missing/unexpected keys; they are counted
        # and reported instead of aborting the load.
        missing, unexpected = m.load_state_dict(sd, strict=False)
        m.eval()
        model = m
        if missing or unexpected:
            load_error = f"Loaded ✅ (missing={len(missing)}, unexpected={len(unexpected)})"
        else:
            load_error = ""
        return True
    except Exception as e:
        # Startup boundary: keep the app running with the model disabled, but
        # leave the full traceback in the console for debugging.
        traceback.print_exc()
        load_error = f"Load error: {e}"
        return False


model_loaded = load_model()


# ─────────────────────────────────────────────────────────────────
# SAMPLERS
# ─────────────────────────────────────────────────────────────────
@torch.no_grad()
def _predict_x0(x_t, t_val):
    """Predict clean x0 from noisy x_t at timestep t.

    Returns:
        ``(x0_hat, eps)``: the estimate clamped to [-1, 1] and the raw model
        output.
    """
    t = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
    sac = _g(sqrt_alphas_cumprod, t, x_t.shape, DEVICE)
    somac = _g(sqrt_one_minus_alphas_cumprod, t, x_t.shape, DEVICE)
    eps = model(x_t, t)
    x0_hat = (x_t - somac * eps) / sac  # Tweedie / rearranged forward eq.
    return x0_hat.clamp(-1.0, 1.0), eps


@torch.no_grad()
def ddpm_step(x_t, t_val):
    """
    Correct DDPM reverse step: q(x_{t-1} | x_t, x_0_hat)
    — NO skipping, runs every single timestep.

    Returns the posterior mean plus sampled noise, or the mean alone at
    ``t_val == 0``.
    """
    t = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
    x0, _ = _predict_x0(x_t, t_val)
    c1 = _g(posterior_mean_coef1, t, x_t.shape, DEVICE)
    c2 = _g(posterior_mean_coef2, t, x_t.shape, DEVICE)
    mean = c1 * x0 + c2 * x_t
    if t_val == 0:
        return mean
    log_var = _g(posterior_log_variance, t, x_t.shape, DEVICE)
    noise = torch.randn_like(x_t)
    # exp(0.5 * log_var) is the posterior standard deviation.
    return mean + (0.5 * log_var).exp() * noise


@torch.no_grad()
def ddim_step(x_t, t_val, t_prev, eta=0.0):
    """
    DDIM deterministic step (eta=0) or stochastic (eta>0).
    Allows large timestep skips while maintaining quality.

    ``t_prev`` is the timestep being moved to; a negative value means the
    final step, for which the cumulative alpha is taken as 1.
    """
    t = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
    ac_t = _g(alphas_cumprod, t, x_t.shape, DEVICE)
    somac_t = _g(sqrt_one_minus_alphas_cumprod, t, x_t.shape, DEVICE)

    # ac_prev as a tensor broadcastable to x_t
    ac_prev_val = alphas_cumprod[t_prev].item() if t_prev >= 0 else 1.0
    ac_prev = torch.tensor(ac_prev_val, device=DEVICE, dtype=ac_t.dtype)

    eps = model(x_t, t)
    x0_hat = ((x_t - somac_t * eps) / ac_t.sqrt()).clamp(-1.0, 1.0)

    if t_prev >= 0 and eta > 0.0:
        sigma = eta * (
            (1.0 - ac_prev) / (1.0 - ac_t) * (1.0 - ac_t / ac_prev)
        ).clamp(min=0).sqrt()
    else:
        sigma = torch.zeros(1, device=DEVICE)

    dir_xt = (1.0 - ac_prev - sigma ** 2).clamp(min=0.0).sqrt() * eps
    noise = sigma * torch.randn_like(x_t)
    return ac_prev.sqrt() * x0_hat + dir_xt + noise


def _ddim_timesteps(t_start, n_steps):
    """Return evenly spaced, strictly decreasing timesteps t_start .. 0.

    Both end points are always included and at most ``n_steps`` values are
    returned (``n_steps`` is limited to the range 1 .. t_start + 1).
    """
    n_steps = max(1, min(int(n_steps), t_start + 1))
    spaced = np.linspace(t_start, 0, n_steps).round().astype(int).tolist()
    # dict.fromkeys drops any duplicate produced by rounding, keeping order.
    return list(dict.fromkeys(spaced))


# ─────────────────────────────────────────────────────────────────
# FORWARD NOISE
# ─────────────────────────────────────────────────────────────────
def add_noise(tensor, t_val):
    """q(x_t | x_0) — closed-form forward process.

    Returns:
        ``(x_t, noise)``. ``x_t`` is deliberately not clamped: clamping would
        distort the Gaussian forward distribution the samplers invert and
        would make ``x_t`` inconsistent with the returned ``noise``.
        ``to_pil`` already clips when the tensor is displayed.
    """
    t = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
    sac = _g(sqrt_alphas_cumprod, t, tensor.shape, DEVICE)
    somac = _g(sqrt_one_minus_alphas_cumprod, t, tensor.shape, DEVICE)
    noise = torch.randn_like(tensor)
    return sac * tensor + somac * noise, noise


# ─────────────────────────────────────────────────────────────────
# IMAGE HELPERS
# ─────────────────────────────────────────────────────────────────
def preprocess(pil_img, sz, brightness, contrast, blur, sharpen, invert, equalize):
    """Convert to ``sz`` x ``sz`` grayscale and apply the optional adjustments.

    Adjustments run in a fixed order: brightness, contrast, Gaussian blur,
    unsharp mask, invert, histogram equalisation. Neutral values (1.0 for
    brightness/contrast, 0 for blur/sharpen, False for the flags) skip the
    corresponding step.
    """
    img = pil_img.convert("L").resize((sz, sz), Image.LANCZOS)
    if brightness != 1.0:
        img = ImageEnhance.Brightness(img).enhance(brightness)
    if contrast != 1.0:
        img = ImageEnhance.Contrast(img).enhance(contrast)
    if blur > 0:
        img = img.filter(ImageFilter.GaussianBlur(radius=blur))
    if sharpen > 0:
        img = img.filter(
            ImageFilter.UnsharpMask(radius=2, percent=int(sharpen * 150), threshold=3)
        )
    if invert:
        img = ImageOps.invert(img)
    if equalize:
        img = ImageOps.equalize(img)
    return img


def to_tensor(img):
    """Convert an image to a (1, 1, H, W) float tensor in [-1, 1] on DEVICE.

    Assumes a single-channel image such as the one ``preprocess`` returns.
    """
    arr = np.array(img).astype(np.float32) / 255.0
    return (torch.from_numpy(arr).unsqueeze(0).unsqueeze(0) * 2 - 1).to(DEVICE)


def to_pil(t):
    """Render batch item 0, channel 0 of a [-1, 1] tensor as an RGB image.

    Values outside the range are clipped to 0..255.
    """
    arr = ((t[0, 0].cpu().float().numpy() + 1) / 2 * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(arr, "L").convert("RGB")


# ─────────────────────────────────────────────────────────────────
# GENERATE FROM NOISE
# ─────────────────────────────────────────────────────────────────
def generate_image(num_steps, image_size, seed, snap_count, sampler, eta):
    """Sample one image from pure noise.

    Args:
        num_steps: Requested number of DDIM steps (limited to 1..T). Ignored
            by DDPM, which has to visit every timestep.
        image_size: Side length of the square single-channel sample.
        seed: Seed passed to ``torch.manual_seed``.
        snap_count: Approximate number of intermediate snapshots to keep.
        sampler: ``"DDIM"`` for the skipping sampler; anything else runs DDPM.
        eta: DDIM stochasticity, forwarded to ``ddim_step``.

    Returns:
        ``(final_image, snapshots, status_message)``; ``(None, [], message)``
        when no model is loaded.
    """
    if not model_loaded:
        return None, [], f"⚠️ {load_error}"

    torch.manual_seed(int(seed))
    sz = int(image_size)
    x = torch.randn((1, 1, sz, sz), device=DEVICE)

    use_ddim = sampler == "DDIM"
    if use_ddim:
        # Start at T - 1, where x really is (almost) pure noise, and pair each
        # timestep with the NEXT LOWER one; -1 marks the final step to x0.
        seq = _ddim_timesteps(T - 1, num_steps)
        seq_prev = seq[1:] + [-1]
    else:
        # DDPM — every step. Starting anywhere below T - 1 would hand the
        # model pure noise labelled as a nearly clean timestep.
        seq = list(range(T - 1, -1, -1))
        seq_prev = []

    # max(1, ...) guards against a zero or negative snapshot count.
    every = max(1, len(seq) // max(1, int(snap_count)))
    last_idx = len(seq) - 1
    snaps = []
    for idx, t_cur in enumerate(seq):
        if use_ddim:
            x = ddim_step(x, t_cur, seq_prev[idx], eta=float(eta))
        else:
            x = ddpm_step(x, t_cur)
        if idx % every == 0 or idx == last_idx:
            snaps.append(to_pil(x.clamp(-1, 1)))

    # Report the number of steps actually executed.
    return to_pil(x.clamp(-1, 1)), snaps, f"✅ Done ({sampler}, {len(seq)} steps)"
# ─────────────────────────────────────────────────────────────────
# DENOISE UPLOADED IMAGE — FIXED
# ─────────────────────────────────────────────────────────────────
def denoise_image(uploaded, noise_level, num_steps, seed, sampler, eta,
                  img_size, brightness, contrast, blur, sharpen, invert, equalize):
    """Noise an uploaded image to a chosen timestep and reconstruct it.

    Args:
        uploaded: Image array passed to ``Image.fromarray``; ``None`` when
            nothing was uploaded.
        noise_level: Fraction of the schedule to noise to; mapped to a
            timestep in 1 .. T - 1.
        num_steps: Requested number of DDIM steps (limited to 1 .. t + 1).
            Ignored by DDPM, which visits every timestep.
        seed: Seed passed to ``torch.manual_seed``.
        sampler: ``"DDIM"`` for the skipping sampler; anything else runs DDPM.
        eta: DDIM stochasticity, forwarded to ``ddim_step``.
        img_size, brightness, contrast, blur, sharpen, invert, equalize:
            Forwarded to ``preprocess``.

    Returns:
        ``(preprocessed, noisy, one_step_estimate, reconstruction, status)``.
        The four images are ``None`` when no model is loaded or no image was
        uploaded.
    """
    if not model_loaded:
        return None, None, None, None, f"⚠️ {load_error}"
    if uploaded is None:
        return None, None, None, None, "⚠️ Please upload an image."

    torch.manual_seed(int(seed))
    sz = int(img_size)
    prc = preprocess(Image.fromarray(uploaded), sz,
                     brightness, contrast, blur, sharpen, invert, equalize)
    pre_rgb = prc.convert("RGB")
    x0 = to_tensor(prc)  # clean image in [-1,1]

    # ── correct forward process ──────────────────────────
    t_val = max(1, min(T - 1, int(float(noise_level) * (T - 1))))
    x_noisy, _ = add_noise(x0, t_val)
    noisy_pil = to_pil(x_noisy)

    # ── reconstruction ───────────────────────────────────
    x = x_noisy.clone()
    if sampler == "DDIM":
        # Evenly spaced timesteps from t_val down to 0, both ends included.
        # max(1, ...) rejects zero/negative requests; linspace yields exactly
        # n values, where a floor-divided stride could yield almost 2n.
        n = max(1, min(int(num_steps), t_val + 1))
        spaced = np.linspace(t_val, 0, n).round().astype(int).tolist()
        seq = list(dict.fromkeys(spaced))  # drop rounding duplicates, keep order
        seq_prev = seq[1:] + [-1]  # next lower timestep; -1 = final step
        for t_cur, t_prev in zip(seq, seq_prev):
            x = ddim_step(x, t_cur, t_prev, eta=float(eta))
        steps_run = len(seq)
    else:
        # DDPM: must step EVERY timestep from t_val down to 0 — no skipping
        for t_cur in range(t_val, -1, -1):
            x = ddpm_step(x, t_cur)
        steps_run = "all"

    recon_pil = to_pil(x.clamp(-1, 1))

    # ── predicted x0 directly (fast single-step estimate) ──
    x0_direct, _ = _predict_x0(x_noisy, t_val)
    direct_pil = to_pil(x0_direct)

    # The status reports the steps actually executed, not the requested count.
    return (pre_rgb, noisy_pil, direct_pil, recon_pil,
            f"✅ Done ({sampler}, t={t_val}, steps={steps_run})")


def show_artifact(which):
    """Load one of the saved training artifacts by its menu label.

    Returns:
        ``(image, status_message)``; the image is ``None`` when the label is
        unknown or the file does not exist.
    """
    artifact_files = {
        "📉 Loss Curve": "loss_plot.png",
        "🖼️ Reconstruction": "reconstruction.png",
        "🔄 Reverse Steps": "reverse_steps.png",
    }
    path = artifact_files.get(which, "")
    if path and os.path.exists(path):
        # Image.open is lazy and keeps the file open; copy the pixels inside
        # a context manager so the handle is released right away.
        with Image.open(path) as src:
            img = src.copy()
        return img, f"✅ {path}"
    return None, f"⚠️ Not found: {path}"
───────────────────────────────────────────────────────────────── # CSS # ───────────────────────────────────────────────────────────────── CSS = """ @import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=Syne:wght@400;600;700;800&display=swap'); :root { --bg:#07071a; --card:#0c0c22; --inp:#10102a; --acc:#a78bfa; --grn:#34d399; --cyn:#38bdf8; --glow:#7c3aed; --txt:#e2e8f0; --mut:#64748b; --bdr:rgba(167,139,250,.14); --r:12px; } body,.gradio-container{background:var(--bg)!important;font-family:'Syne',sans-serif!important;color:var(--txt)!important;} /* hero */ #hero{background:linear-gradient(160deg,#0b0629,#120b36,#071528);border:1px solid var(--bdr);border-radius:var(--r);padding:2.6rem 2rem 2rem;text-align:center;position:relative;overflow:hidden;margin-bottom:1.4rem;} #hero::before{content:'';position:absolute;inset:0;background:radial-gradient(ellipse 70% 55% at 50% 0%,rgba(124,58,237,.3),transparent 68%);} #hero::after{content:'';position:absolute;inset:0;background:url("data:image/svg+xml,%3Csvg width='60' height='60' viewBox='0 0 60 60' xmlns='http://www.w3.org/2000/svg'%3E%3Cg fill='%239C92AC' fill-opacity='0.03'%3E%3Cpath d='M36 34v-4h-2v4h-4v2h4v4h2v-4h4v-2h-4zm0-30V0h-2v4h-4v2h4v4h2V6h4V4h-4zM6 34v-4H4v4H0v2h4v4h2v-4h4v-2H6zM6 4V0H4v4H0v2h4v4h2V6h4V4H6z'/%3E%3C/g%3E%3C/svg%3E");} .htitle{font-size:2.9rem;font-weight:800;letter-spacing:-1.5px;background:linear-gradient(100deg,var(--acc),var(--cyn) 55%,var(--grn));-webkit-background-clip:text;-webkit-text-fill-color:transparent;margin:0 0 .3rem;position:relative;} .hsub{font-family:'Space Mono',monospace;font-size:.87rem;color:var(--mut);margin:0;position:relative;} .badge{display:inline-flex;align-items:center;gap:.4rem;margin-top:.8rem;padding:.28rem 1rem;border-radius:999px;font-family:'Space Mono',monospace;font-size:.74rem;font-weight:700;border:1px solid rgba(167,139,250,.3);background:rgba(167,139,250,.08);color:var(--acc);position:relative;} 
.badge.ok{border-color:rgba(52,211,153,.35);background:rgba(52,211,153,.08);color:var(--grn);} .badge.err{border-color:rgba(248,113,113,.35);background:rgba(248,113,113,.08);color:#f87171;} /* tabs */ .tab-nav{border-bottom:1px solid var(--bdr)!important;} .tab-nav button{font-family:'Syne',sans-serif!important;font-weight:600!important;font-size:.88rem!important;color:var(--mut)!important;background:transparent!important;border:none!important;border-bottom:2px solid transparent!important;padding:.7rem 1.4rem!important;transition:all .2s!important;} .tab-nav button.selected{color:var(--acc)!important;border-bottom-color:var(--acc)!important;} /* inputs */ input[type=number],input[type=text],textarea{background:var(--inp)!important;border:1px solid var(--bdr)!important;color:var(--txt)!important;border-radius:8px!important;font-family:'Space Mono',monospace!important;font-size:.82rem!important;} input[type=range]{accent-color:var(--acc)!important;} label span,.label-wrap span{font-family:'Space Mono',monospace!important;font-size:.72rem!important;color:var(--mut)!important;text-transform:uppercase;letter-spacing:.07em;} .gr-check-radio{accent-color:var(--acc)!important;} /* buttons */ button.primary,.gr-button-primary{background:linear-gradient(135deg,var(--glow),#0e7490)!important;color:#fff!important;border:none!important;border-radius:9px!important;font-family:'Syne',sans-serif!important;font-weight:700!important;font-size:.94rem!important;padding:.65rem 2rem!important;box-shadow:0 0 20px rgba(124,58,237,.4)!important;transition:all .2s!important;width:100%;} button.primary:hover{box-shadow:0 0 36px rgba(124,58,237,.65)!important;transform:translateY(-2px)!important;} /* section label */ .sl{font-family:'Space Mono',monospace;font-size:.7rem;font-weight:700;letter-spacing:.1em;text-transform:uppercase;color:var(--mut);padding-bottom:.4rem;border-bottom:1px solid var(--bdr);margin-bottom:.7rem;} /* tip box */ .tip{background:rgba(56,189,248,.06);border:1px solid 
rgba(56,189,248,.18);border-radius:8px;padding:.7rem 1rem;font-family:'Space Mono',monospace;font-size:.78rem;color:var(--cyn);margin-top:.5rem;} /* images */ .gr-image img,.output-image img{border-radius:10px!important;border:1px solid var(--bdr)!important;} .gallery-item{border-radius:8px!important;} .gr-textbox textarea{font-family:'Space Mono',monospace!important;font-size:.8rem!important;} #footer{text-align:center;padding:1.2rem;font-family:'Space Mono',monospace;font-size:.74rem;color:var(--mut);margin-top:1rem;border-top:1px solid var(--bdr);} """ bc = "ok" if model_loaded else "err" bi = "🟢" if model_loaded else "🔴" bm = "Model Loaded · DDPM Ready" if model_loaded else f"Model Not Found — place ddpm_model.pth next to app.py" # ───────────────────────────────────────────────────────────────── # UI # ───────────────────────────────────────────────────────────────── with gr.Blocks(title="Noise2Vision — DDPM") as demo: gr.HTML(f"""
Denoising Diffusion Probabilistic Model · Reverse the noise, reveal the signal