| import streamlit as st |
| from PIL import Image, ImageOps |
| import torch |
| from matplotlib.image import imread |
| import numpy as np |
| import tensorflow as tf |
| import math |
| import torch.nn.functional as F |
| from tqdm.auto import tqdm |
| from torchvision import transforms |
| import matplotlib.pyplot as plt |
|
|
| from torch import nn |
# Side length, in pixels, of the square images the app generates.
img_size = 64
# Batch-size constant; the visible code below passes a literal 64 instead.
BATCH_SIZE = 64
# Prefer a CUDA device when PyTorch reports one, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
class Block(nn.Module):
    """Time-conditioned convolutional stage of the U-Net.

    The stage runs conv -> batch-norm -> SiLU twice, adds a projected time
    embedding to the feature map in between, and finishes with a stride-2
    resampling layer stored in ``self.transform``.

    Args:
        in_ch: Input channel count. When ``up`` is true the first convolution
            takes ``2 * in_ch`` channels.
        out_ch: Channel count produced by every layer of the stage.
        time_emb_dim: Width of the time-embedding vector given to ``forward``.
        up: Selects ``ConvTranspose2d`` (true) or ``Conv2d`` (false) as the
            resampling layer.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Projects the time embedding to one scalar per output channel.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)

        first_in_ch = 2 * in_ch if up else in_ch
        self.conv1 = nn.Conv2d(first_in_ch, out_ch, kernel_size=3, padding=1)
        if up:
            self.transform = nn.ConvTranspose2d(
                out_ch, out_ch, kernel_size=4, stride=2, padding=1
            )
            # Registered but not called by forward().
            self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        else:
            self.transform = nn.Conv2d(
                out_ch, out_ch, kernel_size=4, stride=2, padding=1
            )
            # Registered but not called by forward().
            self.maxpool = nn.MaxPool2d(kernel_size=4, stride=2, padding=1)

        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.silu = nn.SiLU()
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Apply the stage to feature map ``x`` conditioned on embedding ``t``.

        Returns the output of ``self.transform`` applied to the processed
        features.
        """
        h = self.conv1(x)
        h = self.bnorm1(h)
        h = self.silu(h)

        # Append two trailing singleton axes so the per-channel embedding
        # broadcasts over the two trailing axes of ``h``.
        time_emb = self.relu(self.time_mlp(t))
        h = h + time_emb[..., None, None]

        h = self.conv2(h)
        h = self.bnorm2(h)
        h = self.silu(h)

        return self.transform(h)
|
|
|
|
class SinusoidalPositionEmbeddings(nn.Module):
    """Encode a 1-D batch of timesteps as sine/cosine features.

    Each timestep is multiplied by ``dim // 2`` frequencies spaced
    geometrically from 1 down to 1/10000; the sines of those products are
    concatenated with their cosines along the last axis.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return the embedding of ``time``, one row per timestep."""
        n_freqs = self.dim // 2
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(-log_step * torch.arange(n_freqs, device=time.device))
        # Outer product: one row per timestep, one column per frequency.
        angles = time.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
|
|
|
class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.

    A 3-channel input is lifted to 32 channels, passed through four
    downsampling ``Block`` stages (32 -> 64 -> 128 -> 256 -> 512) and four
    upsampling stages that mirror them, then projected back to 3 channels
    with a 1x1 convolution. Every stage is conditioned on a 32-wide
    embedding of the timestep.
    """

    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (32, 64, 128, 256, 512)
        up_channels = (512, 256, 128, 64, 32)
        out_dim = 3
        time_emb_dim = 32

        # Timestep -> sinusoidal features -> learned projection.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Initial projection of the image into feature space.
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # One Block per consecutive pair of channel widths.
        self.downs = nn.ModuleList([
            Block(c_in, c_out, time_emb_dim)
            for c_in, c_out in zip(down_channels[:-1], down_channels[1:])
        ])
        self.ups = nn.ModuleList([
            Block(c_in, c_out, time_emb_dim, up=True)
            for c_in, c_out in zip(up_channels[:-1], up_channels[1:])
        ])

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        """Run the network on image batch ``x`` at the given ``timestep``s."""
        t = self.time_mlp(timestep)
        x = self.conv0(x)

        # Keep each down stage's output so the matching up stage can
        # concatenate it along the channel axis.
        skips = []
        for down in self.downs:
            x = down(x, t)
            skips.append(x)

        # The most recently stored output is consumed first.
        for up, skip in zip(self.ups, reversed(skips)):
            x = up(torch.cat((x, skip), dim=1), t)

        return self.output(x)
|
|
# NOTE: the network is instantiated once, further down, right before its
# checkpoint is loaded. Building an extra copy here was redundant: it was
# rebound before anything read it.
|
|
|
|
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` values spaced evenly from ``beta_start`` to ``beta_end``.

    Args:
        timesteps: Number of entries in the schedule.
        beta_start: First value of the schedule (inclusive).
        beta_end: Last value of the schedule (inclusive).

    Returns:
        A 1-D tensor of length ``timesteps``.
    """
    return torch.linspace(beta_start, beta_end, timesteps)
|
|
# Number of steps in the schedule and the per-step beta values.
timesteps = 300
betas = linear_beta_schedule(timesteps=timesteps)

# Quantities derived from the schedule, computed once at import time.
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Cumulative product shifted right by one step, with 1.0 in the first slot.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1. - alphas_cumprod).sqrt()
# The first entry is exactly 0 because alphas_cumprod_prev[0] == 1.
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
|
|
|
def extract(a, t, x_shape):
    """Pick ``a[t[i]]`` for every batch element and shape it for broadcasting.

    Args:
        a: 1-D tensor to index; the lookup is done with ``t`` moved to CPU.
        t: 1-D tensor of integer indices, one per batch element.
        x_shape: Shape of the tensor the result will be combined with; only
            its length is used.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)`` axes,
        placed on ``t``'s device.
    """
    picked = a.gather(-1, t.cpu())
    target_shape = (t.shape[0],) + (1,) * (len(x_shape) - 1)
    return picked.reshape(target_shape).to(t.device)
|
|
@torch.no_grad()
def p_sample(model, x, t, t_index):
    """Perform one reverse step of the sampling chain.

    Args:
        model: Callable invoked as ``model(x, t)``.
        x: Current image batch.
        t: 1-D integer tensor holding the timestep of every batch element.
        t_index: Plain-integer step index; when it is 0 no noise is added.

    Returns:
        The computed mean when ``t_index == 0``; otherwise the mean plus
        Gaussian noise scaled by the square root of the posterior variance
        at ``t``.
    """
    beta_t = extract(betas, t, x.shape)
    sqrt_one_minus_acp_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alpha_t = extract(sqrt_recip_alphas, t, x.shape)

    # Subtract the scaled model output from x, then rescale.
    model_out = model(x, t)
    model_mean = sqrt_recip_alpha_t * (x - beta_t * model_out / sqrt_one_minus_acp_t)

    if t_index != 0:
        variance_t = extract(posterior_variance, t, x.shape)
        fresh_noise = torch.randn_like(x)
        return model_mean + torch.sqrt(variance_t) * fresh_noise
    return model_mean
|
|
| |
@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the full reverse chain starting from Gaussian noise.

    Args:
        model: Network passed through to ``p_sample``; its first parameter
            decides which device the sampling runs on.
        shape: Shape of the batch to generate; ``shape[0]`` is the batch size.

    Returns:
        A list with one NumPy array per timestep, in the order the steps
        were executed (``timesteps - 1`` down to 0); the last entry is the
        final sample.
    """
    device = next(model.parameters()).device
    b = shape[0]

    img = torch.randn(shape, device=device)
    imgs = []

    # ``total`` must be given explicitly because ``reversed(range(...))`` is
    # an iterator without a length; it has to equal the number of steps for
    # the progress bar to be meaningful.
    for i in tqdm(reversed(range(timesteps)), desc='sampling loop time step', total=timesteps):
        t = torch.full((b,), i, device=device, dtype=torch.long)
        # Pass the real step index so p_sample skips the noise on step 0.
        img = p_sample(model, img, t, i)
        imgs.append(img.cpu().numpy())
    return imgs
|
|
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    """Generate a batch of square images with ``p_sample_loop``.

    Args:
        model: Network forwarded to ``p_sample_loop``.
        image_size: Height and width of each generated image.
        batch_size: Number of images generated together.
        channels: Number of channels per image.

    Returns:
        Whatever ``p_sample_loop`` returns for the shape
        ``(batch_size, channels, image_size, image_size)``.
    """
    batch_shape = (batch_size, channels, image_size, image_size)
    return p_sample_loop(model, shape=batch_shape)
|
|
|
|
|
|
# Build the network and restore the trained weights on CPU.
model = SimpleUnet()

st.title("Generating images using a diffusion model")
model.load_state_dict(
    torch.load("new_linear_model_1090.pt", map_location=torch.device('cpu'))
)

# Converts one generated (C, H, W) tensor into a PIL image. The arithmetic
# maps [-1, 1] onto [0, 255]; values are clamped first because casting an
# out-of-range float to uint8 wraps around instead of saturating.
reverse_transforms = transforms.Compose([
    transforms.Lambda(lambda t: torch.clamp(t, -1.0, 1.0)),
    transforms.Lambda(lambda t: (t + 1) / 2),
    transforms.Lambda(lambda t: t.permute(1, 2, 0)),
    transforms.Lambda(lambda t: t * 255.),
    transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
    transforms.ToPILImage(),
])

if st.button("Click to generate image"):
    samples = sample(model, image_size=img_size, batch_size=64, channels=3)
    # The last list entry is the batch after the final step; show its first
    # image.
    final_batch = samples[-1]
    first_image = torch.from_numpy(final_batch[0]).reshape(3, img_size, img_size)
    img = reverse_transforms(first_image)

    # st.image takes the PIL image directly; plt.imshow() would return a
    # Matplotlib AxesImage, which Streamlit cannot display.
    st.image(img)