import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
from torchvision import transforms
from tqdm.auto import tqdm

# Assumption: the SimpleUnet denoising network is defined in a local model.py;
# adjust this import to match where it actually lives in your project.
from model import SimpleUnet


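# Linear variance schedule from Ho et al. (2020): beta_t increases linearly
# from 1e-4 to 0.02 over the given number of timesteps.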
def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


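# Pre-compute the closed-form constants of the diffusion process that the
# reverse (sampling) steps below look up by timestep.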
timesteps = 300
betas = linear_beta_schedule(timesteps=timesteps)

alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)


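# Gather the per-timestep coefficients a[t] for a batch of timesteps t and
# reshape to (batch_size, 1, 1, ...) so they broadcast over image tensors.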
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


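# One reverse-diffusion step (Algorithm 2 in Ho et al., 2020): given x_t and
# timestep t, predict the noise with the model and sample x_{t-1}.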
@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

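    # Equation 11 in the DDPM paper: use the model's noise prediction to
    # compute the mean of the reverse-process distribution for x_{t-1}.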
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

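    # At t_index == 0 the mean is returned directly (no noise added); at every
    # other step, noise scaled by the posterior variance is added.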
    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise


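# Full sampling loop: iterate the reverse process from t = T-1 down to 0 and
# keep a snapshot of the batch at every step.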
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
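    # Start from pure Gaussian noise x_T.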
    img = torch.randn(shape, device=device)
    imgs = []

    # Iterate t = T-1, ..., 0, passing the loop index as t_index so that the
    # final step (t_index == 0) returns the mean without added noise.
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs

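# Convenience wrapper: build the (batch, channels, height, width) shape and
# run the full sampling loop.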
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))


model = SimpleUnet()

st.title("Generating images using a diffusion model")
# map_location lets the checkpoint load on CPU-only machines as well.
model.load_state_dict(torch.load("new_linear_model_1090.pt", map_location=torch.device("cpu")))
model.eval()

# Assumption: img_size must match the resolution the checkpoint was trained
# at; 64 is a placeholder, so adjust it to your model.
img_size = 64

if st.button("Click to generate image"):
    samples = sample(model, image_size=img_size, batch_size=64, channels=3)
    # Map a generated tensor from [-1, 1] back to a displayable PIL image.
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])
    # Display the first 10 images from the final denoising step.
    for i in range(10):
        img = reverse_transforms(torch.from_numpy(samples[-1][i].reshape(3, img_size, img_size)))
        st.image(img)