|
|
import gradio as gr |
|
|
import torch |
|
|
from torch import nn |
|
|
import torchvision |
|
|
from torchvision.transforms import ToTensor |
|
|
from types import SimpleNamespace |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
# Load the whole pickled UNet module (not just a state_dict) onto the CPU.
# weights_only=False is required on torch >= 2.6, where the default became
# weights_only=True and refuses to unpickle arbitrary nn.Module objects.
# NOTE: unpickling can execute arbitrary code -- only load trusted checkpoints.
unet = torch.load(
    "unet_checkpoint13.pt",
    map_location=torch.device("cpu"),
    weights_only=False,
).to("cpu")

# Inference mode for layers whose behaviour differs between train and eval
# (e.g. dropout / normalisation layers, if the model has any).
unet.eval()
|
|
|
|
|
@torch.no_grad()
def linear_sched(betamin=0.0001, betamax=0.02, n_steps=1000):
    """Build a linear beta noise schedule for DDPM.

    Args:
        betamin: beta at the first timestep.
        betamax: beta at the last timestep.
        n_steps: number of timesteps (length of every returned tensor).

    Returns:
        SimpleNamespace of 1-D tensors, each of length ``n_steps``:
            a    -- alpha_t = 1 - beta_t
            abar -- cumulative product of alpha_0..alpha_t
            sig  -- sqrt(beta_t)
    """
    beta = torch.linspace(betamin, betamax, n_steps)
    alpha = 1.0 - beta
    return SimpleNamespace(a=alpha, abar=alpha.cumprod(dim=0), sig=beta.sqrt())
|
|
# Number of diffusion timesteps. generate() loops over exactly this many reverse
# steps, so the schedule below is built with the same length.
n_steps = 1000

# Linear beta schedule with betamax lowered to 0.01 (linear_sched defaults to 0.02).
lin_abar = linear_sched(betamax=0.01, n_steps=n_steps)

# Per-timestep schedule tensors, indexed by t in [0, n_steps).
alphabar = lin_abar.abar  # cumulative product of alpha up to t
alpha = lin_abar.a        # 1 - beta_t
sigma = lin_abar.sig      # sqrt(beta_t)
|
|
|
|
|
def _to_display(x):
    """Convert the first image of a (batch, 1, H, W) tensor to a 2-D numpy array for gr.Image.

    The +0.5 shift presumably undoes a training normalisation that centred pixels
    on 0 (TODO confirm against the training transform). The clamp to [-1, 1]
    presumably keeps the float array inside the range gr.Image accepts -- verify.
    """
    return (x[0].float().cpu() + 0.5).squeeze().clamp(-1, 1).detach().numpy()


@torch.no_grad()
def generate():
    """Run DDPM ancestral sampling with the global ``unet`` and stream snapshots.

    Starts from Gaussian noise of shape (1, 1, 32, 32) and denoises over
    ``n_steps`` reverse steps using the module-level ``alphabar`` / ``alpha`` /
    ``sigma`` schedule.

    Yields:
        (img_800, img_50, img_final): three 2-D numpy arrays for the three
        gr.Image panes, emitted on every step in ``sample_at``. While t >= 799
        all three show the current sample; the first pane then stays at the
        t == 799 snapshot, the second stays at the t == 50 snapshot, and the
        third keeps updating until the final (t == 0) image.
    """
    model = unet
    sz = (1, 1, 32, 32)
    ps = next(model.parameters())  # only used for the model's device / dtype
    x_t = torch.randn(sz).to(ps)

    # Speed trick: the UNet is evaluated only at these timesteps (every step for
    # t < 99, increasingly sparse towards t = n_steps); the other steps reuse the
    # most recent noise prediction. The UI is refreshed on the same steps.
    sample_at = {t for t in range(n_steps) if (t + 101) % ((t + 101) // 100) == 0}

    img_final = img_800 = img_50 = _to_display(x_t)
    noise = None
    for t in reversed(range(n_steps)):
        t_batch = torch.full((x_t.shape[0],), t, device=ps.device, dtype=torch.long)
        # No noise is added on the very last step.
        z = (torch.randn(x_t.shape) if t > 0 else torch.zeros(x_t.shape)).to(ps)

        abar_t = alphabar[t]
        abar_prev = alphabar[t - 1] if t > 0 else torch.tensor(1.0)
        bbar_t = 1 - abar_t
        bbar_prev = 1 - abar_prev

        # Always predict on the first step so `noise` is defined even when
        # n_steps - 1 is not in sample_at.
        if noise is None or t in sample_at:
            # Model output exposes `.sample` (presumably a diffusers UNet output).
            noise = model(x_t, t_batch).sample

        # Estimate x_0, then step to x_{t-1} via the posterior mean of
        # q(x_{t-1} | x_t, x_0) plus sigma_t * z.
        x_0_hat = (x_t - bbar_t.sqrt() * noise) / abar_t.sqrt()
        x_t = (
            x_0_hat * abar_prev.sqrt() * (1 - alpha[t]) / bbar_t
            + x_t * alpha[t].sqrt() * bbar_prev / bbar_t
            + sigma[t] * z
        )

        if t in sample_at:
            img = _to_display(x_t)
            if t >= 799:
                img_final = img_800 = img_50 = img
            elif t >= 50:
                img_final = img_50 = img
            else:
                img_final = img
            yield img_800, img_50, img_final
|
|
|
|
|
# Gradio UI: one button that streams three snapshots of the reverse diffusion
# process produced by generate().
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">UNet with DDPM</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    session_data = gr.State([])

    sampling_button = gr.Button("Unconditional image generation")

    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">image at step 800</h3>""")
            step_800_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">image at step 50</h3>""")
            step_50_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">final image</h3>""")
            step_final_image = gr.Image(height=250, width=200)

    # generate() takes no inputs and yields a 3-tuple, one array per pane.
    sampling_button.click(
        generate,
        [],
        [step_800_image, step_50_image, step_final_image],
    )
|
|
|
|
|
if __name__ == "__main__":
    # queue() enables Gradio's request queue, which older Gradio versions
    # require for generator (streaming) event handlers such as generate().
    demo.queue().launch(share=False, inbrowser=True)
|
|
|