"""Streamlit demo: generate images with a simple diffusion U-Net.

Loads pretrained weights from ``new_linear_model_1090.pt`` and exposes a
button that runs one denoising pass of the U-Net on random noise.

NOTE(review): the original file was collapsed onto one line, never imported
``torch.nn`` (NameError on ``nn.Module``), and ended with ``model()`` called
with no arguments (TypeError — ``forward`` needs ``(x, timestep)``).  A full
reverse-diffusion sampling loop additionally needs the beta schedule, which
is not present in this file — TODO: port it from the training script.
"""

import math

# NOTE(review): tensorflow, matplotlib, PIL and numpy appear unused in this
# file — confirm nothing else relies on their import side effects, then drop.
import numpy as np
import streamlit as st
import tensorflow as tf
import torch
import torch.nn as nn  # was missing in the original: every class below uses `nn`
from matplotlib.image import imread
from PIL import Image, ImageOps


class Block(nn.Module):
    """One U-Net stage: conv -> add time embedding -> conv -> resample.

    Args:
        in_ch: input channels.  When ``up=True`` the first conv takes
            ``2 * in_ch`` because the decoder concatenates the skip
            connection onto the feature map before calling this block.
        out_ch: output channels.
        time_emb_dim: size of the timestep-embedding vector fed to ``forward``.
        up: if True the final transform upsamples (ConvTranspose2d, stride 2);
            otherwise it downsamples (strided Conv2d).
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        # (the original also built an unused nn.Upsample / nn.MaxPool2d here;
        # both are parameterless, so removing them leaves the state_dict intact)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.silu = nn.SiLU()
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Apply the stage to feature map ``x`` conditioned on time embedding ``t``.

        Args:
            x: feature map of shape (B, C, H, W).
            t: time embedding of shape (B, time_emb_dim).

        Returns:
            The resampled feature map (spatial size halved or doubled).
        """
        # First conv.
        h = self.silu(self.bnorm1(self.conv1(x)))
        # Project the time embedding to out_ch channels.
        time_emb = self.relu(self.time_mlp(t))
        # Extend with two trailing singleton dims so it broadcasts over H, W.
        time_emb = time_emb[(...,) + (None,) * 2]
        # Inject the time signal as a per-channel bias.
        h = h + time_emb
        # Second conv.
        h = self.silu(self.bnorm2(self.conv2(h)))
        # Down- or upsample.
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    """Map integer timesteps to ``dim``-dimensional sinusoidal embeddings."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return embeddings of shape (len(time), dim) for 1-D timestep tensor."""
        device = time.device
        half_dim = self.dim // 2
        # Geometric frequency progression from 1 down to 1/10000.
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        # TODO: Double check the ordering here (sin first, then cos).
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class SimpleUnet(nn.Module):
    """A simplified variant of the U-Net architecture.

    Five down stages (32 -> 512 channels) mirrored by five up stages, with
    skip connections concatenated channel-wise, conditioned on a learned
    timestep embedding.  Input and output are 3-channel images.
    """

    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (32, 64, 128, 256, 512)
        up_channels = (512, 256, 128, 64, 32)
        out_dim = 3
        time_emb_dim = 32

        # Time embedding: sinusoidal features -> linear -> ReLU.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Initial projection to the first channel count.
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsampling path.
        self.downs = nn.ModuleList(
            [
                Block(down_channels[i], down_channels[i + 1], time_emb_dim)
                for i in range(len(down_channels) - 1)
            ]
        )
        # Upsampling path.
        self.ups = nn.ModuleList(
            [
                Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True)
                for i in range(len(up_channels) - 1)
            ]
        )

        # Edit: Corrected a bug found by Jakub C (see YouTube comment).
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        """Predict noise for image batch ``x`` at timesteps ``timestep``.

        Args:
            x: image batch of shape (B, 3, H, W); H and W must be divisible
               by 16 so the four stride-2 stages round-trip cleanly.
            timestep: 1-D tensor of B integer timesteps.

        Returns:
            Tensor of shape (B, 3, H, W).
        """
        # Embed time.
        t = self.time_mlp(timestep)
        # Initial conv.
        x = self.conv0(x)
        # U-Net: record skip activations on the way down ...
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        # ... and concatenate them (LIFO) as extra channels on the way up.
        for up in self.ups:
            residual_x = residual_inputs.pop()
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)


model = SimpleUnet()
# map_location keeps this working on CPU-only hosts even if the checkpoint
# was saved from a GPU run.
model.load_state_dict(
    torch.load("new_linear_model_1090.pt", map_location=torch.device("cpu"))
)
# eval() so BatchNorm uses running statistics rather than batch statistics.
model.eval()

st.title("Generating images using a diffusion model")

if st.button("Click to generate image"):
    # NOTE(review): the original called ``model()`` with no arguments, which
    # raises TypeError.  This runs a single denoising pass on pure noise;
    # proper sampling needs the full reverse-diffusion loop (beta schedule
    # not available in this file).  64x64 assumed — TODO confirm the
    # resolution the checkpoint was trained at.
    with torch.no_grad():
        noise = torch.randn(1, 3, 64, 64)
        t = torch.zeros(1, dtype=torch.long)
        denoised = model(noise, t)
    # Rescale from the model's output range to [0, 1] for display.
    img = denoised.squeeze(0).clamp(-1, 1).add(1).div(2)
    st.image(img.permute(1, 2, 0).numpy())