|
|
import torch |
|
|
from torch import nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
|
def timestep_embedding(tsteps, emb_dim, max_period=10000):
    """Sinusoidal timestep embeddings (DDPM / transformer style).

    Args:
        tsteps: 1-D tensor of timesteps / noise levels, shape (batch,).
        emb_dim: size of the embedding produced for each timestep.
        max_period: controls the lowest sinusoid frequency.

    Returns:
        Tensor of shape (batch, emb_dim); when emb_dim is odd the final
        column is zero-padded.
    """
    half = emb_dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    freqs = torch.exp(-math.log(max_period) * torch.linspace(0, 1, half, device=tsteps.device))
    args = tsteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    out = torch.cat((torch.sin(args), torch.cos(args)), dim=-1)
    if emb_dim % 2 == 1:
        out = F.pad(out, (0, 1, 0, 0))
    return out
|
|
|
|
|
def lin(ni, nf, act=nn.SiLU, norm=None, bias=True):
    """Pre-activation linear layer: [norm] -> [act] -> Linear(ni, nf).

    norm/act may be None to skip that stage; they are layer classes
    (e.g. nn.BatchNorm1d, nn.SiLU), instantiated here.
    """
    parts = []
    if norm:
        parts.append(norm(ni))
    if act:
        parts.append(act())
    parts.append(nn.Linear(ni, nf, bias=bias))
    return nn.Sequential(*parts)
|
|
|
|
|
def unet_conv(ni, nf, act=nn.SiLU, norm=None, bias=True, ks=3, stride=1):
    """Pre-activation conv layer: [norm] -> [act] -> Conv2d.

    Padding is ks//2, so spatial size is preserved at stride 1.
    norm/act are layer classes (or None to omit that stage).
    """
    parts = []
    if norm:
        parts.append(norm(ni))
    if act:
        parts.append(act())
    parts.append(nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks // 2, bias=bias))
    return nn.Sequential(*parts)
|
|
|
|
|
class EmbResBlock(nn.Module):
    """Residual conv block modulated by an embedding, FiLM-style.

    The embedding `t` is projected to per-channel (scale, shift) pairs and
    applied between the two convs; a 1x1 conv aligns the skip path when the
    channel counts differ.
    """

    def __init__(self, n_emb, ni, nf, act=nn.SiLU, norm=nn.BatchNorm2d, bias=True, ks=3):
        super().__init__()
        nf = ni if nf is None else nf
        # nf*2 because the projection supplies both scale and shift.
        self.emb_proj = nn.Linear(n_emb, nf * 2)
        self.conv1 = unet_conv(ni, nf, norm=norm, act=act, ks=ks)
        self.conv2 = unet_conv(nf, nf, norm=norm, act=act, ks=ks)
        self.idconv = nn.Conv2d(ni, nf, 1) if ni != nf else nn.Identity()

    def forward(self, x, t):
        skip = x
        # (batch, 2*nf) -> (batch, 2*nf, 1, 1) so it broadcasts over H, W.
        cond = self.emb_proj(F.silu(t))[:, :, None, None]
        scale, shift = cond.chunk(2, dim=1)
        x = self.conv1(x)
        x = x * (1 + scale) + shift
        x = self.conv2(x)
        return x + self.idconv(skip)
|
|
|
|
|
from functools import wraps |
|
|
|
|
|
def saved(m, blk):
    """Patch m.forward so every output is also appended to blk.saved.

    Used to record intermediate activations for UNet skip connections.
    Returns m (mutated in place).
    """
    orig_forward = m.forward

    @wraps(m.forward)
    def forward_and_record(*args, **kwargs):
        out = orig_forward(*args, **kwargs)
        blk.saved.append(out)
        return out

    m.forward = forward_and_record
    return m
|
|
|
|
|
class DownBlock(nn.Module):
    """A stack of embedding-conditioned res blocks plus an optional strided
    downsample; all intermediate outputs land in self.saved (via saved())
    for later use as UNet skip connections."""

    def __init__(self, n_emb, ni, nf, add_down, num_layers=1):
        super().__init__()
        blocks = []
        for i in range(num_layers):
            in_ch = ni if i == 0 else nf
            blocks.append(saved(EmbResBlock(n_emb, in_ch, nf), self))
        self.resnets = nn.ModuleList(blocks)
        if add_down:
            self.down = saved(nn.Conv2d(nf, nf, 3, stride=2, padding=1), self)
        else:
            self.down = nn.Identity()

    def forward(self, x, t):
        # Reset per call; the saved() wrappers append into this list.
        self.saved = []
        for block in self.resnets:
            x = block(x, t)
        return self.down(x)
|
|
|
|
|
def upsample(nf):
    """2x spatial upsample followed by a 3x3 conv (channel count unchanged)."""
    return nn.Sequential(
        nn.Upsample(scale_factor=2.),
        nn.Conv2d(nf, nf, 3, padding=1),
    )
|
|
|
|
|
class UpBlock(nn.Module):
    """Res blocks that consume skip activations (channel-concatenated from
    `ups`, popped newest-first), followed by an optional 2x upsample."""

    def __init__(self, n_emb, ni, prev_nf, nf, add_up, num_layers=1):
        super().__init__()
        blocks = []
        for i in range(num_layers):
            # First layer sees the previous stage's output (prev_nf), later
            # ones see nf; the last layer's skip comes from a stage with ni
            # channels, earlier skips carry nf channels.
            in_ch = prev_nf if i == 0 else nf
            skip_ch = ni if i == num_layers - 1 else nf
            blocks.append(EmbResBlock(n_emb, in_ch + skip_ch, nf))
        self.resnets = nn.ModuleList(blocks)
        self.up = upsample(nf) if add_up else nn.Identity()

    def forward(self, x, t, ups):
        for block in self.resnets:
            x = block(torch.cat([x, ups.pop()], dim=1), t)
        return self.up(x)
|
|
|
|
|
class EmbUnetModel(nn.Module, PyTorchModelHubMixin):
    """Class-conditional UNet denoiser.

    forward takes a tuple (x, t, c): images of in_channels, timesteps/noise
    levels, and integer class labels (10 classes). The timestep embedding is
    passed through an MLP, summed with the class embedding, and conditions
    every residual block.
    """

    def __init__(self, in_channels=1, out_channels=1, nfs=(32, 64, 128, 256), n_layers=2):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, nfs[0], kernel_size=3, padding=1)
        self.n_temb = nf = nfs[0]
        n_emb = nf * 4
        # 10 classes; class embedding is added to the timestep MLP output.
        self.cond_emb = nn.Embedding(10, n_emb)
        self.mlp_emb = nn.Sequential(
            lin(self.n_temb, n_emb, norm=nn.BatchNorm1d),
            lin(n_emb, n_emb),
        )

        self.downs = nn.ModuleList()
        for i, width in enumerate(nfs):
            ni, nf = nf, width
            self.downs.append(
                DownBlock(n_emb, ni, nf, add_down=i != len(nfs) - 1, num_layers=n_layers)
            )

        self.mid_block = EmbResBlock(n_emb, ni=nfs[-1], nf=None)

        rev_nfs = list(reversed(nfs))
        nf = rev_nfs[0]
        self.ups = nn.ModuleList()
        for i in range(len(rev_nfs)):
            prev_nf, nf = nf, rev_nfs[i]
            # Channel count of the skip activations arriving at this stage.
            ni = rev_nfs[min(i + 1, len(nfs) - 1)]
            self.ups.append(
                UpBlock(n_emb, ni, prev_nf, nf, add_up=i != len(nfs) - 1, num_layers=n_layers + 1)
            )

        self.conv_out = unet_conv(ni=nfs[0], nf=out_channels, norm=nn.BatchNorm2d, act=nn.SiLU, bias=False)

    def forward(self, inp):
        x, t, c = inp
        temb = timestep_embedding(t, self.n_temb)
        emb = self.mlp_emb(temb) + self.cond_emb(c)
        x = self.conv_in(x)
        saved = [x]
        for block in self.downs:
            x = block(x, emb)
        # Collect all down-path activations once, in down order; the up path
        # pops them back newest-first as skip connections.
        saved += [p for o in self.downs for p in o.saved]
        x = self.mid_block(x, emb)
        for block in self.ups:
            x = block(x, emb, saved)
        return self.conv_out(x)
|
|
|
|
|
import torch |
|
|
import matplotlib.pyplot as plt |
|
|
import torchvision.transforms.functional as TF |
|
|
|
|
|
|
|
|
# Fashion-MNIST class names; list index == the integer label the model's
# class embedding was trained on.
LABELS = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot'
]
|
|
|
|
|
def sigmas_karras(n, sigma_min=0.01, sigma_max=80., rho=7.):
    """Karras et al. (2022) noise schedule.

    Returns n sigmas decreasing from sigma_max to sigma_min, spaced evenly
    in sigma**(1/rho), with a trailing 0 appended (n + 1 values total).
    """
    ramp = torch.linspace(0, 1, n)
    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    sigmas = (hi + ramp * (lo - hi)) ** rho
    return torch.cat([sigmas, torch.tensor([0.])])
|
|
|
|
|
def scalings(sig, sig_data=0.66):
    """Karras preconditioning coefficients for noise level `sig`.

    Returns (c_skip, c_out, c_in) given the data standard deviation
    `sig_data`; `sig` must be a tensor (uses .sqrt()).
    """
    total_var = sig ** 2 + sig_data ** 2
    std = total_var.sqrt()
    return (
        sig_data ** 2 / total_var,  # c_skip: weight on the noisy input
        sig * sig_data / std,       # c_out: weight on the model output
        1 / std,                    # c_in: input scaling
    )
|
|
|
|
|
def denoise(model, x, sig, label, device):
    """One preconditioned model call: the denoised estimate of x at noise
    level sig, conditioned on the integer class label."""
    sig = sig[None].to(device)
    c_skip, c_out, c_in = scalings(sig)
    labels = torch.tensor([label], device=device)
    pred = model((x * c_in, sig, labels))
    return pred * c_out + x * c_skip
|
|
|
|
|
@torch.no_grad()
def sample_euler(x, sigs, i, model, label, device):
    """Single Euler ODE-sampler step from sigs[i] down to sigs[i+1]."""
    sig_cur, sig_next = sigs[i], sigs[i + 1]
    # d = dx/dsigma estimated from the denoised prediction.
    d = (x - denoise(model, x, sig_cur, label, device)) / sig_cur
    return x + d * (sig_next - sig_cur)
|
|
|
|
|
def generate(class_name, model, steps=100, sigma_max=80., device=None):
    """Sample one 32x32 image of the requested Fashion-MNIST class.

    Args:
        class_name: one of LABELS (matched case-insensitively, surrounding
            whitespace ignored).
        model: conditional denoiser, called as model((x, sig, label)).
        steps: number of sampler steps (sigmas_karras adds a final 0).
        sigma_max: starting noise level.
        device: torch device; picks 'cuda' when available if None.

    Returns:
        CPU tensor of shape (1, 32, 32).

    Raises:
        ValueError: if class_name matches no label.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    wanted = class_name.strip().lower()
    label_idx = next(
        (idx for idx, name in enumerate(LABELS) if name.lower() == wanted),
        None,
    )

    if label_idx is None:
        available = "\n".join([f" {i}: {name}" for i, name in enumerate(LABELS)])
        raise ValueError(
            f"Invalid class name: '{class_name}'\n\n"
            f"Available classes:\n{available}"
        )

    model.eval()
    model.to(device)

    # Start from pure noise at the largest sigma.
    x = torch.randn(1, 1, 32, 32).to(device) * sigma_max

    sigs = sigmas_karras(steps, sigma_max=sigma_max).to(device)

    for i in range(len(sigs) - 1):
        x = sample_euler(x, sigs, i, model, label_idx, device)

    return x.squeeze(0).cpu()
|
|
|
|
|
|
|
|
def show_image(image_tensor):
    """
    Display a generated image.

    Args:
        image_tensor: Tensor of shape [1, 32, 32] or [32, 32]
    """
    img = image_tensor.squeeze(0) if image_tensor.dim() == 3 else image_tensor

    plt.figure(figsize=(4, 4))
    # Samples live roughly in [-1, 1]; pin the colormap to that range.
    plt.imshow(img, cmap='gray', vmin=-1, vmax=1)
    plt.axis('off')
    plt.tight_layout()
    plt.show()