Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from PIL import Image, ImageFilter, ImageEnhance, ImageOps | |
| import math, os, traceback | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MODEL | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal encoding of integer timesteps.

    For ``dim`` features, ``dim // 2`` geometrically spaced frequencies are used;
    the output is ``[sin(t * f), cos(t * f)]`` concatenated on the last axis.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        # time: 1-D tensor of timesteps (indexed as time[:, None] below).
        half = self.dim // 2
        log_step = math.log(10000) / (half - 1)
        freqs = torch.exp(-log_step * torch.arange(half, device=time.device))
        phases = time[:, None] * freqs[None, :]
        return torch.cat([phases.sin(), phases.cos()], dim=-1)
class Block(nn.Module):
    """U-Net stage: two 3x3 convs with a time-embedding bias, then a resample.

    Down blocks (``up=False``) halve H/W with a stride-2 ``Conv2d``. Up blocks
    (``up=True``) take ``2 * in_ch`` input channels (feature map concatenated
    with its skip connection) and double H/W with a ``ConvTranspose2d``.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # initialisation and state_dict key order are unchanged.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        first_in = 2 * in_ch if up else in_ch
        self.conv1 = nn.Conv2d(first_in, out_ch, 3, padding=1)
        resample_cls = nn.ConvTranspose2d if up else nn.Conv2d
        self.transform = resample_cls(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        act = self.relu
        feats = self.bnorm1(act(self.conv1(x)))
        # Project the time embedding to out_ch and add two trailing spatial axes
        # so it broadcasts over H and W.
        time_bias = act(self.time_mlp(t))[..., None, None]
        feats = self.bnorm2(act(self.conv2(feats + time_bias)))
        return self.transform(feats)
class SimpleUnet(nn.Module):
    """Small U-Net noise predictor conditioned on the diffusion timestep.

    The encoder walks ``down_channels`` pairwise; the decoder mirrors it, and each
    up block consumes the matching encoder output as a skip connection.
    """

    def __init__(self, image_channels=1,
                 down_channels=(64, 128, 256, 512, 1024),
                 time_emb_dim=32, out_dim=1):
        super().__init__()
        up_channels = tuple(down_channels)[::-1]
        # Registration order (time_mlp, conv0, downs, ups, output) is kept so the
        # state_dict keys and parameter init order match existing checkpoints.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
        self.downs = nn.ModuleList(
            Block(c_in, c_out, time_emb_dim)
            for c_in, c_out in zip(down_channels, down_channels[1:])
        )
        self.ups = nn.ModuleList(
            Block(c_in, c_out, time_emb_dim, up=True)
            for c_in, c_out in zip(up_channels, up_channels[1:])
        )
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        t_emb = self.time_mlp(timestep)
        feats = self.conv0(x)
        skip_stack = []
        for stage in self.downs:
            feats = stage(feats, t_emb)
            skip_stack.append(feats)
        for stage in self.ups:
            # Deepest encoder output pairs with the first decoder stage (LIFO).
            feats = stage(torch.cat((feats, skip_stack.pop()), dim=1), t_emb)
        return self.output(feats)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # AUTO-DETECT ARCH FROM CHECKPOINT | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def detect_arch(sd):
    """Infer ``SimpleUnet`` constructor kwargs from a checkpoint state dict.

    Sizes are read from the stored weight shapes:

    * ``time_emb_dim``   - rows of ``time_mlp.1.weight`` (the Linear after the
      sinusoidal embedding).
    * ``image_channels`` - input channels of ``conv0.weight``.
    * ``down_channels``  - output channels of ``conv0`` followed by those of each
      ``downs.{i}.conv1``.
    * ``out_dim``        - output channels of ``output.weight`` (the final 1x1 conv).

    Missing keys fall back to defaults: 32, 1, widths doubling from 64 across
    4 down blocks, and 1. ``out_dim`` is detected because ``load_state_dict``
    raises on shape mismatches even with ``strict=False``.

    Returns:
        dict of keyword arguments accepted by ``SimpleUnet``.
    """
    def _dim(key, axis, default):
        # Read one shape entry without allocating placeholder tensors.
        weight = sd.get(key)
        return weight.shape[axis] if weight is not None else default

    time_emb_dim = _dim("time_mlp.1.weight", 0, 32)
    image_channels = _dim("conv0.weight", 1, 1)
    n_down = sum(
        1 for k in sd if k.startswith("downs.") and k.endswith(".conv1.weight")
    ) or 4
    down_channels = [_dim("conv0.weight", 0, 64)]
    for i in range(n_down):
        down_channels.append(
            _dim(f"downs.{i}.conv1.weight", 0, down_channels[-1] * 2)
        )
    out_dim = _dim("output.weight", 0, 1)
    return dict(image_channels=image_channels,
                down_channels=tuple(down_channels),
                time_emb_dim=time_emb_dim,
                out_dim=out_dim)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # DIFFUSION SCHEDULE (pre-compute everything once) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Linear beta schedule. Every schedule tensor below is 1-D with length T and is
# looked up per timestep with _g().
T = 300
betas = torch.linspace(0.0001, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)  # alpha_bar_t
# alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

# Forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
_one_minus_alphas_cumprod = 1.0 - alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(_one_minus_alphas_cumprod)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# Posterior q(x_{t-1} | x_t, x0) = N(coef1 * x0 + coef2 * x_t, posterior_variance).
# The variance is exactly 0 at t=0, so it is floored before taking the log.
posterior_variance = torch.clamp(
    betas * (1.0 - alphas_cumprod_prev) / _one_minus_alphas_cumprod, min=1e-20
)
posterior_log_variance = torch.log(posterior_variance)
posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / _one_minus_alphas_cumprod
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / _one_minus_alphas_cumprod
def _g(vals, t, x_shape, device):
    """Look up ``vals[t]`` for each batch entry and shape it to broadcast over ``x_shape``.

    ``vals`` lives on the CPU, so the index is moved there for the gather and
    the result is moved to ``device``. The output has shape
    ``(len(t), 1, ..., 1)`` with ``len(x_shape)`` dimensions.
    """
    target_shape = (t.shape[0],) + (1,) * (len(x_shape) - 1)
    picked = vals.gather(-1, t.cpu())
    return picked.to(device).reshape(target_shape)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # LOAD MODEL | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH = "ddpm_model.pth"
model = None     # set by load_model() on success
load_error = ""  # status/warning text shown in the UI; "" when the load was clean


def _extract_state_dict(raw):
    """Return a plain ``{name: tensor}`` state dict from a loaded checkpoint.

    Accepts a bare state dict, a training checkpoint dict holding the weights
    under ``model_state_dict`` or ``state_dict``, or a pickled ``nn.Module``.
    If every key carries a ``module.`` prefix (as saved from an
    ``nn.DataParallel``/DDP wrapper), the prefix is stripped so the names match
    ``SimpleUnet``'s parameters.
    """
    if isinstance(raw, dict):
        sd = raw.get("model_state_dict", raw.get("state_dict", raw))
    else:
        sd = raw.state_dict()
    prefix = "module."
    if sd and all(k.startswith(prefix) for k in sd):
        sd = {k[len(prefix):]: v for k, v in sd.items()}
    return sd


def load_model():
    """Load ``MODEL_PATH`` into the global ``model``.

    The architecture is inferred from the checkpoint via ``detect_arch``.

    Returns:
        True when a model was loaded. In that case ``load_error`` is "" or a
        warning listing missing/unexpected key counts. On failure it returns
        False, leaves ``model`` unchanged and puts the reason in ``load_error``.
    """
    global model, load_error
    if not os.path.exists(MODEL_PATH):
        load_error = f"ddpm_model.pth not found at: {os.path.abspath(MODEL_PATH)}"
        return False
    try:
        # NOTE(security): weights_only=False unpickles arbitrary Python objects.
        # This is only acceptable because MODEL_PATH is a file shipped with the
        # app; never point it at a user-supplied checkpoint.
        raw = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
        sd = _extract_state_dict(raw)
        m = SimpleUnet(**detect_arch(sd)).to(DEVICE)
        missing, unexpected = m.load_state_dict(sd, strict=False)
    except Exception as e:  # UI boundary: report any load failure as a status message
        traceback.print_exc()
        load_error = f"Load error: {e}"
        return False
    if len(missing) == len(m.state_dict()):
        # Nothing matched: the model would run with random weights.
        load_error = "Load error: no checkpoint keys matched the model (all parameters missing)"
        return False
    m.eval()
    model = m
    load_error = f"Loaded β (missing={len(missing)}, unexpected={len(unexpected)})" if (missing or unexpected) else ""
    return True


model_loaded = load_model()
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # SAMPLERS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
@torch.no_grad()
def _predict_x0(x_t, t_val):
    """Estimate the clean image x0 from x_t with one network call.

    Inverts x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps using the
    network's noise prediction eps. The timestep tensor has length 1, so the
    same ``t_val`` applies to every sample in the batch.

    Runs under ``torch.no_grad()``. This is inference only; without it the
    outputs carry autograd history and cannot be converted with ``.numpy()``.

    Returns:
        (x0_hat clamped to [-1, 1], eps)
    """
    t = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
    sac = _g(sqrt_alphas_cumprod, t, x_t.shape, DEVICE)
    somac = _g(sqrt_one_minus_alphas_cumprod, t, x_t.shape, DEVICE)
    eps = model(x_t, t)
    x0_hat = (x_t - somac * eps) / sac  # forward equation solved for x0
    return x0_hat.clamp(-1.0, 1.0), eps
@torch.no_grad()
def ddpm_step(x_t, t_val):
    """One ancestral DDPM reverse step: sample x_{t-1} ~ q(x_{t-1} | x_t, x0_hat).

    The posterior only describes a single-step move t -> t-1, so callers must
    visit every timestep with no skipping. At ``t_val == 0`` the posterior mean
    is returned without added noise.

    Runs under ``torch.no_grad()``. Chained calls would otherwise keep the
    autograd graph of every previous step alive, so memory would grow over the
    whole sampling loop.
    """
    t = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
    x0, _ = _predict_x0(x_t, t_val)
    c1 = _g(posterior_mean_coef1, t, x_t.shape, DEVICE)
    c2 = _g(posterior_mean_coef2, t, x_t.shape, DEVICE)
    mean = c1 * x0 + c2 * x_t
    if t_val == 0:
        return mean
    log_var = _g(posterior_log_variance, t, x_t.shape, DEVICE)
    noise = torch.randn_like(x_t)
    return mean + (0.5 * log_var).exp() * noise  # std = exp(0.5 * log variance)
@torch.no_grad()
def ddim_step(x_t, t_val, t_prev, eta=0.0):
    """One DDIM update from timestep ``t_val`` to ``t_prev``.

    ``t_prev`` may be several steps below ``t_val``, so a strided subsequence
    of timesteps can be used. ``t_prev = -1`` marks the final jump: alpha_bar_prev
    is then 1 and the clamped x0 estimate is returned.
    ``eta = 0`` gives the deterministic update. ``eta > 0`` adds noise with std
    ``sigma``; the UI describes eta = 1 as "DDPM-like".

    Runs under ``torch.no_grad()`` so chained calls do not build an autograd
    graph across the sampling loop.
    """
    t = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
    ac_t = _g(alphas_cumprod, t, x_t.shape, DEVICE)
    somac_t = _g(sqrt_one_minus_alphas_cumprod, t, x_t.shape, DEVICE)
    # alpha_bar at the target timestep, as a tensor that broadcasts against x_t.
    ac_prev_val = alphas_cumprod[t_prev].item() if t_prev >= 0 else 1.0
    ac_prev = torch.tensor(ac_prev_val, device=DEVICE, dtype=ac_t.dtype)
    eps = model(x_t, t)
    x0_hat = ((x_t - somac_t * eps) / ac_t.sqrt()).clamp(-1.0, 1.0)
    if t_prev >= 0 and eta > 0.0:
        sigma = eta * ((1.0 - ac_prev) / (1.0 - ac_t) * (1.0 - ac_t / ac_prev)).clamp(min=0).sqrt()
    else:
        sigma = torch.zeros(1, device=DEVICE)
    # "Direction pointing to x_t" term; the clamp guards against tiny negative rounding.
    dir_xt = (1.0 - ac_prev - sigma**2).clamp(min=0.0).sqrt() * eps
    noise = sigma * torch.randn_like(x_t)
    return ac_prev.sqrt() * x0_hat + dir_xt + noise
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # FORWARD NOISE | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def add_noise(tensor, t_val):
    """Sample x_t ~ q(x_t | x_0) in closed form.

    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, with eps ~ N(0, I).

    x_t is deliberately NOT clamped. For mid and large t its values routinely
    fall outside [-1, 1], and clamping would truncate the Gaussian that the
    reverse samplers start from. The noise predictor presumably saw unclamped
    x_t during training, but the training code is not visible here. For display,
    ``to_pil`` already clips to the 0-255 range.

    Returns:
        (x_t, eps)
    """
    t = torch.full((1,), t_val, device=DEVICE, dtype=torch.long)
    sac = _g(sqrt_alphas_cumprod, t, tensor.shape, DEVICE)
    somac = _g(sqrt_one_minus_alphas_cumprod, t, tensor.shape, DEVICE)
    noise = torch.randn_like(tensor)
    return sac * tensor + somac * noise, noise
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # IMAGE HELPERS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def preprocess(pil_img, sz, brightness, contrast, blur, sharpen, invert, equalize):
    """Turn an input image into an ``sz`` x ``sz`` grayscale ("L") image.

    Adjustments run in a fixed order, and each one only when enabled:
    brightness, contrast, Gaussian blur, unsharp mask, invert, equalize.
    """
    pipeline = (
        (brightness != 1.0, lambda im: ImageEnhance.Brightness(im).enhance(brightness)),
        (contrast != 1.0, lambda im: ImageEnhance.Contrast(im).enhance(contrast)),
        (blur > 0, lambda im: im.filter(ImageFilter.GaussianBlur(radius=blur))),
        (sharpen > 0, lambda im: im.filter(
            ImageFilter.UnsharpMask(radius=2, percent=int(sharpen * 150), threshold=3))),
        (invert, ImageOps.invert),
        (equalize, ImageOps.equalize),
    )
    out = pil_img.convert("L").resize((sz, sz), Image.LANCZOS)
    for enabled, op in pipeline:
        if enabled:
            out = op(out)
    return out
def to_tensor(img):
    """Map an 8-bit image to a float tensor in [-1, 1] on DEVICE.

    Two leading axes are added, so a 2-D image such as ``preprocess``'s "L"
    output becomes shape (1, 1, H, W).
    """
    unit = np.asarray(img, dtype=np.float32) / 255.0   # [0, 1]
    batched = torch.from_numpy(unit)[None, None]       # add batch and channel axes
    return (batched * 2 - 1).to(DEVICE)
def to_pil(t):
    """Render ``t[0, 0]`` (first sample, first channel) as an RGB PIL image.

    Values are expected in [-1, 1] and mapped to 0-255; anything outside that
    range is clipped. ``detach()`` lets this accept tensors that still carry
    autograd history, because ``Tensor.numpy()`` raises on tensors that
    require grad.
    """
    pixels = t[0, 0].detach().cpu().float().numpy()
    pixels = ((pixels + 1) / 2 * 255).clip(0, 255).astype(np.uint8)
    # A 2-D uint8 array is read as mode "L", so the mode argument is not needed.
    return Image.fromarray(pixels).convert("RGB")
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # GENERATE FROM NOISE | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def generate_image(num_steps, image_size, seed, snap_count, sampler, eta):
    """Sample a new ``image_size`` x ``image_size`` image from pure Gaussian noise.

    DDIM visits a strided subsequence of about ``num_steps`` timesteps, from
    high to low. Each update jumps to the NEXT LOWER timestep in that
    subsequence, and -1 marks the final jump to x0.

    DDPM ignores ``num_steps``. Its posterior step only moves t -> t-1, and pure
    N(0, I) noise corresponds to t = T-1, so it always walks all T timesteps.

    About ``snap_count`` intermediate frames (plus the final one) are collected.

    Returns:
        (final image or None, list of snapshot images, status message)
    """
    if not model_loaded:
        return None, [], f"β οΈ {load_error}"
    torch.manual_seed(int(seed))
    sz = int(image_size)
    if sampler == "DDIM":
        stride = max(1, T // max(1, min(int(num_steps), T)))
        seq = list(range(0, T, stride))[::-1]          # high -> low, ends at 0
        pairs = list(zip(seq, seq[1:] + [-1]))          # (t, next lower t)
    else:
        pairs = [(t, t - 1) for t in range(T - 1, -1, -1)]
    every = max(1, len(pairs) // max(1, int(snap_count)))
    snaps = []
    # Inference only: no_grad keeps each step from extending an autograd graph
    # through all previous steps.
    with torch.no_grad():
        x = torch.randn((1, 1, sz, sz), device=DEVICE)
        for idx, (t_cur, t_prev) in enumerate(pairs):
            if sampler == "DDIM":
                x = ddim_step(x, t_cur, t_prev, eta=float(eta))
            else:
                x = ddpm_step(x, t_cur)
            if idx % every == 0 or idx == len(pairs) - 1:
                snaps.append(to_pil(x.clamp(-1, 1)))
    return to_pil(x.clamp(-1, 1)), snaps, f"β Done ({sampler}, {len(pairs)} steps)"
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # DENOISE UPLOADED IMAGE β FIXED | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def denoise_image(uploaded, noise_level, num_steps, seed, sampler, eta,
                  img_size, brightness, contrast, blur, sharpen, invert, equalize):
    """Run the 4-stage pipeline: preprocess, forward-noise, then reconstruct.

    ``noise_level`` maps to timestep ``t = int(noise_level * (T - 1))``, clipped
    to [1, T - 1]. DDIM reconstructs in about ``num_steps`` strided steps from
    t to 0. DDPM always steps through every timestep from t to 0.

    Returns:
        (preprocessed RGB, noisy x_t, one-step x0 estimate, iterative
        reconstruction, status message). All images are None on early exit.
    """
    if not model_loaded:
        return None, None, None, None, f"β οΈ {load_error}"
    if uploaded is None:
        return None, None, None, None, "β οΈ Please upload an image."
    torch.manual_seed(int(seed))
    sz = int(img_size)
    prc = preprocess(Image.fromarray(uploaded), sz,
                     brightness, contrast, blur, sharpen, invert, equalize)
    pre_rgb = prc.convert("RGB")
    x0 = to_tensor(prc)  # clean image in [-1, 1]
    t_val = max(1, min(T - 1, int(float(noise_level) * (T - 1))))

    # Inference only: without no_grad every reverse step extends the autograd
    # graph, and to_pil's .numpy() call fails on grad-tracking tensors.
    with torch.no_grad():
        x_noisy, _ = add_noise(x0, t_val)
        noisy_pil = to_pil(x_noisy)

        x = x_noisy.clone()
        if sampler == "DDIM":
            # Strided subsequence from t_val down to 0, always including t_val.
            n = max(1, min(int(num_steps), t_val + 1))
            seq = list(range(0, t_val + 1, max(1, (t_val + 1) // n)))
            if seq[-1] != t_val:
                seq.append(t_val)
            seq = seq[::-1]                  # high -> low
            seq_prev = seq[1:] + [-1]        # each step's next lower timestep
            for t_cur, t_prev in zip(seq, seq_prev):
                x = ddim_step(x, t_cur, t_prev, eta=float(eta))
            steps_used = len(seq)
        else:
            # DDPM must step through EVERY timestep from t_val down to 0.
            for t_i in range(t_val, -1, -1):
                x = ddpm_step(x, t_i)
            steps_used = "all"
        recon_pil = to_pil(x.clamp(-1, 1))

        # Fast single-call x0 estimate from the noisy input, for comparison.
        x0_direct, _ = _predict_x0(x_noisy, t_val)
        direct_pil = to_pil(x0_direct)

    return (pre_rgb, noisy_pil, direct_pil, recon_pil,
            f"β Done ({sampler}, t={t_val}, steps={steps_used})")
def show_artifact(which):
    """Open the saved training artifact chosen in the UI radio.

    Paths are relative to the working directory. Returns ``(image, status)``
    when the mapped file exists. Otherwise it returns ``(None, status)``; the
    path in the status is "" for an unknown choice.
    """
    artifact_files = {
        "π Loss Curve": "loss_plot.png",
        "πΌοΈ Reconstruction": "reconstruction.png",
        "π Reverse Steps": "reverse_steps.png",
    }
    path = artifact_files.get(which, "")
    if not path or not os.path.exists(path):
        return None, f"β οΈ Not found: {path}"
    return Image.open(path), f"β {path}"
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # CSS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Custom dark theme passed to demo.launch(). The .htitle gradient text sets both
# the -webkit- prefixed and the standard background-clip so it works across engines.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=Syne:wght@400;600;700;800&display=swap');
:root {
--bg:#07071a; --card:#0c0c22; --inp:#10102a;
--acc:#a78bfa; --grn:#34d399; --cyn:#38bdf8; --glow:#7c3aed;
--txt:#e2e8f0; --mut:#64748b; --bdr:rgba(167,139,250,.14); --r:12px;
}
body,.gradio-container{background:var(--bg)!important;font-family:'Syne',sans-serif!important;color:var(--txt)!important;}
/* hero */
#hero{background:linear-gradient(160deg,#0b0629,#120b36,#071528);border:1px solid var(--bdr);border-radius:var(--r);padding:2.6rem 2rem 2rem;text-align:center;position:relative;overflow:hidden;margin-bottom:1.4rem;}
#hero::before{content:'';position:absolute;inset:0;background:radial-gradient(ellipse 70% 55% at 50% 0%,rgba(124,58,237,.3),transparent 68%);}
#hero::after{content:'';position:absolute;inset:0;background:url("data:image/svg+xml,%3Csvg width='60' height='60' viewBox='0 0 60 60' xmlns='http://www.w3.org/2000/svg'%3E%3Cg fill='%239C92AC' fill-opacity='0.03'%3E%3Cpath d='M36 34v-4h-2v4h-4v2h4v4h2v-4h4v-2h-4zm0-30V0h-2v4h-4v2h4v4h2V6h4V4h-4zM6 34v-4H4v4H0v2h4v4h2v-4h4v-2H6zM6 4V0H4v4H0v2h4v4h2V6h4V4H6z'/%3E%3C/g%3E%3C/svg%3E");}
.htitle{font-size:2.9rem;font-weight:800;letter-spacing:-1.5px;background:linear-gradient(100deg,var(--acc),var(--cyn) 55%,var(--grn));-webkit-background-clip:text;background-clip:text;-webkit-text-fill-color:transparent;margin:0 0 .3rem;position:relative;}
.hsub{font-family:'Space Mono',monospace;font-size:.87rem;color:var(--mut);margin:0;position:relative;}
.badge{display:inline-flex;align-items:center;gap:.4rem;margin-top:.8rem;padding:.28rem 1rem;border-radius:999px;font-family:'Space Mono',monospace;font-size:.74rem;font-weight:700;border:1px solid rgba(167,139,250,.3);background:rgba(167,139,250,.08);color:var(--acc);position:relative;}
.badge.ok{border-color:rgba(52,211,153,.35);background:rgba(52,211,153,.08);color:var(--grn);}
.badge.err{border-color:rgba(248,113,113,.35);background:rgba(248,113,113,.08);color:#f87171;}
/* tabs */
.tab-nav{border-bottom:1px solid var(--bdr)!important;}
.tab-nav button{font-family:'Syne',sans-serif!important;font-weight:600!important;font-size:.88rem!important;color:var(--mut)!important;background:transparent!important;border:none!important;border-bottom:2px solid transparent!important;padding:.7rem 1.4rem!important;transition:all .2s!important;}
.tab-nav button.selected{color:var(--acc)!important;border-bottom-color:var(--acc)!important;}
/* inputs */
input[type=number],input[type=text],textarea{background:var(--inp)!important;border:1px solid var(--bdr)!important;color:var(--txt)!important;border-radius:8px!important;font-family:'Space Mono',monospace!important;font-size:.82rem!important;}
input[type=range]{accent-color:var(--acc)!important;}
label span,.label-wrap span{font-family:'Space Mono',monospace!important;font-size:.72rem!important;color:var(--mut)!important;text-transform:uppercase;letter-spacing:.07em;}
.gr-check-radio{accent-color:var(--acc)!important;}
/* buttons */
button.primary,.gr-button-primary{background:linear-gradient(135deg,var(--glow),#0e7490)!important;color:#fff!important;border:none!important;border-radius:9px!important;font-family:'Syne',sans-serif!important;font-weight:700!important;font-size:.94rem!important;padding:.65rem 2rem!important;box-shadow:0 0 20px rgba(124,58,237,.4)!important;transition:all .2s!important;width:100%;}
button.primary:hover{box-shadow:0 0 36px rgba(124,58,237,.65)!important;transform:translateY(-2px)!important;}
/* section label */
.sl{font-family:'Space Mono',monospace;font-size:.7rem;font-weight:700;letter-spacing:.1em;text-transform:uppercase;color:var(--mut);padding-bottom:.4rem;border-bottom:1px solid var(--bdr);margin-bottom:.7rem;}
/* tip box */
.tip{background:rgba(56,189,248,.06);border:1px solid rgba(56,189,248,.18);border-radius:8px;padding:.7rem 1rem;font-family:'Space Mono',monospace;font-size:.78rem;color:var(--cyn);margin-top:.5rem;}
/* images */
.gr-image img,.output-image img{border-radius:10px!important;border:1px solid var(--bdr)!important;}
.gallery-item{border-radius:8px!important;}
.gr-textbox textarea{font-family:'Space Mono',monospace!important;font-size:.8rem!important;}
#footer{text-align:center;padding:1.2rem;font-family:'Space Mono',monospace;font-size:.74rem;color:var(--mut);margin-top:1rem;border-top:1px solid var(--bdr);}
"""
# Hero status badge: CSS modifier class, indicator glyph and message text.
if model_loaded:
    bc, bi, bm = "ok", "π’", "Model Loaded Β· DDPM Ready"
else:
    bc, bi, bm = "err", "π΄", "Model Not Found β place ddpm_model.pth next to app.py"
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # UI | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Four tabs: Generate, Denoise Upload, Training Artifacts, About.
with gr.Blocks(title="Noise2Vision β DDPM") as demo:
    gr.HTML(f"""
<div id="hero">
<div class="htitle">β‘ Noise2Vision</div>
<p class="hsub">Denoising Diffusion Probabilistic Model Β· Reverse the noise, reveal the signal</p>
<div class="badge {bc}">{bi} {bm}</div>
</div>
""")
    with gr.Tabs():
        # -- GENERATE --------------------------------------------------------
        with gr.Tab("π² Generate"):
            gr.Markdown("#### Generate a new image from pure Gaussian noise via reverse diffusion.")
            with gr.Row(equal_height=False):
                with gr.Column(scale=1, min_width=270):
                    gr.HTML('<div class="sl">βοΈ Diffusion Controls</div>')
                    g_sampler = gr.Radio(["DDPM", "DDIM"], value="DDIM", label="Sampler")
                    g_steps = gr.Slider(10, 300, value=100, step=10, label="Steps")
                    g_size = gr.Slider(32, 128, value=64, step=32, label="Output Size (px)")
                    g_seed = gr.Number(value=42, label="Random Seed", precision=0)
                    g_snap = gr.Slider(4, 16, value=8, step=2, label="Snapshots")
                    g_eta = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="DDIM Ξ· (0=deterministic, 1=DDPM-like)")
                    gen_btn = gr.Button("β¦ Generate", variant="primary")
                    gen_status = gr.Textbox(label="Status", interactive=False, lines=2)
                with gr.Column(scale=2):
                    gr.HTML('<div class="sl">πΌοΈ Result</div>')
                    gen_out = gr.Image(label="Generated", type="pil", height=300)
                    gr.HTML('<div class="sl" style="margin-top:.8rem">ποΈ Snapshots</div>')
                    gen_gallery = gr.Gallery(label="", columns=8, height=120, allow_preview=True)
            gen_btn.click(generate_image,
                          [g_steps, g_size, g_seed, g_snap, g_sampler, g_eta],
                          [gen_out, gen_gallery, gen_status])

        # -- DENOISE ---------------------------------------------------------
        with gr.Tab("π¬ Denoise Upload"):
            gr.Markdown("#### Upload β preprocess β add noise β reconstruct. Four-stage pipeline.")
            with gr.Row(equal_height=False):
                with gr.Column(scale=1, min_width=295):
                    gr.HTML('<div class="sl">π Image</div>')
                    up_img = gr.Image(label="Upload Image", type="numpy", height=175)
                    gr.HTML('<div class="sl" style="margin-top:.9rem">π¨ Preprocessing</div>')
                    img_size = gr.Slider(32, 128, value=64, step=32, label="Resize (px)")
                    with gr.Row():
                        brightness = gr.Slider(0.3, 2.5, value=1.0, step=0.1, label="Brightness")
                        contrast = gr.Slider(0.3, 2.5, value=1.0, step=0.1, label="Contrast")
                    with gr.Row():
                        blur_s = gr.Slider(0.0, 5.0, value=0.0, step=0.5, label="Blur")
                        sharpen = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Sharpen")
                    with gr.Row():
                        invert_c = gr.Checkbox(label="π Invert", value=False)
                        equalize_c = gr.Checkbox(label="π Equalize", value=False)
                    gr.HTML('<div class="sl" style="margin-top:.9rem">βοΈ Diffusion</div>')
                    d_sampler = gr.Radio(["DDPM", "DDIM"], value="DDIM", label="Sampler")
                    noise_lvl = gr.Slider(0.05, 0.99, value=0.5, step=0.01, label="Noise Level (t/T)")
                    den_steps = gr.Slider(10, 200, value=50, step=5, label="DDIM Steps (ignored for DDPM)")
                    den_eta = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="DDIM Ξ·")
                    den_seed = gr.Number(value=0, label="Seed", precision=0)
                    gr.HTML('<div class="tip">π‘ <b>Tip:</b> DDIM + 50 steps gives sharp, fast reconstruction.<br>DDPM runs all tβ0 steps (slow but theoretically exact).</div>')
                    den_btn = gr.Button("β¦ Reconstruct", variant="primary")
                    den_status = gr.Textbox(label="Status", interactive=False, lines=2)
                with gr.Column(scale=2):
                    gr.HTML('<div class="sl">π 4-Stage Pipeline</div>')
                    with gr.Row():
                        pre_out = gr.Image(label="β Preprocessed", type="pil", height=200)
                        noisy_out = gr.Image(label="β‘ Noisy (t)", type="pil", height=200)
                    with gr.Row():
                        direct_out = gr.Image(label="β’ Direct xβ Estimate (1-step)", type="pil", height=200)
                        recon_out = gr.Image(label="β£ Full Reconstruction", type="pil", height=200)
                    gr.HTML('<div class="tip">β’ is a fast single-step prediction. β£ is the full iterative reverse result.</div>')
            den_btn.click(
                denoise_image,
                [up_img, noise_lvl, den_steps, den_seed, d_sampler, den_eta,
                 img_size, brightness, contrast, blur_s, sharpen, invert_c, equalize_c],
                [pre_out, noisy_out, direct_out, recon_out, den_status],
            )

        # -- ARTIFACTS -------------------------------------------------------
        with gr.Tab("π Training Artifacts"):
            gr.Markdown("#### View saved training outputs.")
            with gr.Row():
                art_radio = gr.Radio(
                    ["π Loss Curve", "πΌοΈ Reconstruction", "π Reverse Steps"],
                    value="π Loss Curve", label="Select"
                )
                view_btn = gr.Button("View β", variant="primary", scale=0)
            art_status = gr.Textbox(label="", interactive=False, lines=1)
            art_out = gr.Image(label="", type="pil", height=460)
            view_btn.click(show_artifact, [art_radio], [art_out, art_status])

        # -- ABOUT -----------------------------------------------------------
        # Markdown notes:
        # - The pipe inside the q(...) code span is escaped (\|); an unescaped
        #   pipe splits a GFM table cell.
        # - The blank line before "---" keeps "Model path" from becoming a
        #   setext heading.
        with gr.Tab("βΉοΈ About"):
            gr.Markdown(f"""
## Noise2Vision β DDPM

### What changed in reconstruction

| Old (broken) | New (fixed) |
|---|---|
| Skipped timesteps with `stride` | DDPM runs **every** step tβ0 |
| Wrong posterior: used `betas * pred / sqrt_omac` | Correct `q(x_{{t-1}}\\|x_t,xΜ_0)` posterior mean |
| Clamped intermediate latents | Only clamp final output |
| No DDIM | **DDIM** added β 50 steps β quality of 300 DDPM steps |
| Single output | **4-stage** output: preprocessed β noisy β direct xΜβ β iterative recon |

### Architecture (auto-detected)

| Component | Detail |
|---|---|
| Backbone | U-Net + sinusoidal time embeddings |
| Encoder | 64β128β256β512β1024 |
| Decoder | 1024β512β256β128β64 |
| T | 300 Β· Linear Ξ² schedule 0.0001β0.02 |

### Files

`ddpm_model.pth` Β· `state.db` Β· `loss_plot.png` Β· `reconstruction.png` Β· `reverse_steps.png`

Model path: `{os.path.abspath(MODEL_PATH)}` β {"β Found" if os.path.exists(MODEL_PATH) else "β Not found"}

---

*Noise2Vision Β· AsadAnalyst Β· Hugging Face Spaces*
""")
    gr.HTML('<div id="footer">Noise2Vision Β· DDPM + DDIM Β· Gradio Β· AsadAnalyst</div>')
if __name__ == "__main__":
    # Styling (custom CSS on top of the Base theme) is supplied at launch time.
    demo.launch(theme=gr.themes.Base(), css=CSS)