# modeling.py — author: Sherwinroger002, commit 384324d ("Update modeling.py")
import torch
from torch import nn
import torch.nn.functional as F
import math
from huggingface_hub import PyTorchModelHubMixin
def timestep_embedding(tsteps, emb_dim, max_period=10000):
    """Sinusoidal embedding of (possibly fractional) timesteps.

    Frequencies are geometrically spaced from 1 down to 1/max_period. The
    result is ``[sin(t*f) ..., cos(t*f) ...]``; when ``emb_dim`` is odd, one
    zero column is appended so the width is exactly ``emb_dim``.

    Args:
        tsteps: 1-D tensor of timesteps / noise levels (indexed as
            ``tsteps[:, None]``, so presumably shape ``(batch,)``).
        emb_dim: Width of the returned embedding.
        max_period: Controls the lowest frequency (1/max_period).

    Returns:
        Float tensor of shape ``(len(tsteps), emb_dim)`` on ``tsteps.device``.
    """
    half = emb_dim // 2
    ramp = torch.linspace(0, 1, half, device=tsteps.device)
    freqs = (-math.log(max_period) * ramp).exp()
    angles = tsteps[:, None].float() * freqs[None, :]
    out = torch.cat([angles.sin(), angles.cos()], dim=-1)
    if emb_dim % 2 == 1:
        # pad one zero on the right of the last dim to reach an odd width
        out = F.pad(out, (0, 1))
    return out
def lin(ni, nf, act=nn.SiLU, norm=None, bias=True):
    """Pre-norm / pre-activation linear layer.

    Builds ``nn.Sequential([norm(ni)], [act()], nn.Linear(ni, nf))``; the norm
    and activation stages are included only when the corresponding argument
    is truthy.
    """
    pieces = []
    if norm:
        pieces.append(norm(ni))
    if act:
        pieces.append(act())
    return nn.Sequential(*pieces, nn.Linear(ni, nf, bias=bias))
def unet_conv(ni, nf, act=nn.SiLU, norm=None, bias=True, ks=3, stride=1):
    """Pre-norm / pre-activation 2-D convolution.

    Builds ``nn.Sequential([norm(ni)], [act()], nn.Conv2d(ni, nf, ks))``, with
    the norm and activation stages present only when their argument is truthy.
    The conv uses ``padding=ks // 2`` so odd kernels keep the spatial size at
    stride 1.
    """
    stages = [norm(ni)] if norm else []
    if act:
        stages.append(act())
    stages.append(
        nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks // 2, bias=bias)
    )
    return nn.Sequential(*stages)
class EmbResBlock(nn.Module):
    """Residual conv block whose features are modulated by an embedding.

    Computes ``conv2(conv1(x) * (1 + scale) + shift) + idconv(x)``, where
    ``scale`` and ``shift`` come from one linear projection of ``silu(t)``.

    Args:
        n_emb: Width of the conditioning embedding ``t``.
        ni: Input channels.
        nf: Output channels; ``None`` means "same as ``ni``".
        act: Activation class used inside both convs.
        norm: Normalisation class used inside both convs.
        bias: Whether the two 3x3 (``ks``) convs carry a bias term.
        ks: Kernel size of the two convs.
    """

    def __init__(self, n_emb, ni, nf=None, act=nn.SiLU, norm=nn.BatchNorm2d, bias=True, ks=3):
        super().__init__()
        if nf is None:
            nf = ni
        # A single projection yields both scale and shift (nf channels each).
        self.emb_proj = nn.Linear(n_emb, nf * 2)
        # `bias` was previously accepted but never forwarded; honour it here.
        self.conv1 = unet_conv(ni, nf, norm=norm, act=act, bias=bias, ks=ks)
        self.conv2 = unet_conv(nf, nf, norm=norm, act=act, bias=bias, ks=ks)
        # 1x1 conv only when the skip path must change channel count.
        self.idconv = nn.Identity() if ni == nf else nn.Conv2d(ni, nf, 1)

    def forward(self, x, t):
        """Apply the block.

        Args:
            x: Feature map with ``ni`` channels.
            t: Conditioning embedding; indexed as ``[:, :, None, None]`` after
               projection, so presumably shape ``(batch, n_emb)``.
        """
        inp = x
        # (batch, 2*nf) -> (batch, 2*nf, 1, 1) so it broadcasts over H and W
        emb = self.emb_proj(F.silu(t))[:, :, None, None]
        x = self.conv1(x)
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = x * (1 + scale) + shift
        x = self.conv2(x)
        return x + self.idconv(inp)
from functools import wraps
def saved(m, blk):
    """Patch ``m.forward`` so each output is also appended to ``blk.saved``.

    The module is patched in place and returned, so the call can be used
    inline when building module lists. ``blk.saved`` must already be a list
    when ``m`` runs (``DownBlock.forward`` resets it on every call).
    """
    original_forward = m.forward

    @wraps(original_forward)
    def recording_forward(*args, **kwargs):
        out = original_forward(*args, **kwargs)
        blk.saved.append(out)
        return out

    m.forward = recording_forward
    return m
class DownBlock(nn.Module):
    """Encoder stage: ``num_layers`` EmbResBlocks, then an optional stride-2 conv.

    Every resnet output, and the downsampling conv output when present, is
    recorded in ``self.saved`` (via ``saved``) for use as U-Net skip inputs.
    ``self.saved`` is reset at the start of every ``forward``.
    """

    def __init__(self, n_emb, ni, nf, add_down, num_layers=1):
        super().__init__()
        blocks = []
        for idx in range(num_layers):
            # only the first resnet sees the incoming channel count
            in_ch = nf if idx else ni
            blocks.append(saved(EmbResBlock(n_emb, in_ch, nf), self))
        self.resnets = nn.ModuleList(blocks)
        self.down = (
            saved(nn.Conv2d(nf, nf, 3, stride=2, padding=1), self)
            if add_down
            else nn.Identity()
        )

    def forward(self, x, t):
        self.saved = []
        for block in self.resnets:
            x = block(x, t)
        return self.down(x)
def upsample(nf):
    """2x nearest-neighbour upsample followed by a same-size 3x3 conv."""
    layers = [
        nn.Upsample(scale_factor=2.),
        nn.Conv2d(nf, nf, 3, padding=1),
    ]
    return nn.Sequential(*layers)
class UpBlock(nn.Module):
    """Decoder stage: EmbResBlocks fed with concatenated skips, then optional upsample.

    Args:
        n_emb: Conditioning embedding width.
        ni: Channels of the skip tensor consumed by the last resnet.
        prev_nf: Channels of the incoming feature map (first resnet).
        nf: Output channels of every resnet.
        add_up: Whether to finish with ``upsample(nf)``.
        num_layers: Number of resnets (one skip tensor is consumed per resnet).
    """

    def __init__(self, n_emb, ni, prev_nf, nf, add_up, num_layers=1):
        super().__init__()
        last = num_layers - 1

        def _in_channels(i):
            # main path width + width of the skip tensor popped for resnet i
            main = prev_nf if i == 0 else nf
            skip = ni if i == last else nf
            return main + skip

        self.resnets = nn.ModuleList(
            [EmbResBlock(n_emb, _in_channels(i), nf) for i in range(num_layers)]
        )
        self.up = upsample(nf) if add_up else nn.Identity()

    def forward(self, x, t, ups):
        """Run the stage; ``ups`` is a list of skip tensors popped from its end (mutated)."""
        for block in self.resnets:
            skip = ups.pop()
            x = block(torch.cat([x, skip], dim=1), t)
        return self.up(x)
class EmbUnetModel(nn.Module, PyTorchModelHubMixin):
    """U-Net conditioned on a timestep / noise level and a class id.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the prediction.
        nfs: Feature width per resolution level; every level except the last
            ends in a stride-2 conv on the way down and an upsample on the way up.
        n_layers: Resnets per down level. Up levels use ``n_layers + 1`` so the
            number of skip tensors they pop equals the number saved on the way
            down (``conv_in`` output + each down level's extra conv output).

    ``forward`` takes one tuple ``(x, t, c)``: the image batch, the value fed
    to ``timestep_embedding``, and integer class ids for the 10-entry
    ``cond_emb`` table.
    """

    def __init__(self, in_channels=1, out_channels=1, nfs=(32, 64, 128, 256), n_layers=2):
        super().__init__()
        # NOTE: attribute names and construction order are kept fixed so that
        # state_dict keys (and hub checkpoints) keep matching.
        base = nfs[0]
        self.conv_in = nn.Conv2d(in_channels, base, kernel_size=3, padding=1)
        self.n_temb = base
        n_emb = base * 4
        self.cond_emb = nn.Embedding(10, n_emb)
        self.mlp_emb = nn.Sequential(
            lin(self.n_temb, n_emb, norm=nn.BatchNorm1d),
            lin(n_emb, n_emb),
        )
        last = len(nfs) - 1

        self.downs = nn.ModuleList()
        width = base
        for level, out_w in enumerate(nfs):
            self.downs.append(
                DownBlock(n_emb, width, out_w, add_down=level != last, num_layers=n_layers)
            )
            width = out_w

        self.mid_block = EmbResBlock(n_emb, ni=nfs[-1], nf=None)

        rev = list(reversed(nfs))
        self.ups = nn.ModuleList()
        prev_w = rev[0]
        for level, out_w in enumerate(rev):
            skip_w = rev[min(level + 1, last)]
            self.ups.append(
                UpBlock(n_emb, skip_w, prev_w, out_w, add_up=level != last, num_layers=n_layers + 1)
            )
            prev_w = out_w

        self.conv_out = unet_conv(ni=base, nf=out_channels, norm=nn.BatchNorm2d, act=nn.SiLU, bias=False)

    def forward(self, inp):
        x, t, c = inp
        temb = timestep_embedding(t, self.n_temb)
        emb = self.mlp_emb(temb) + self.cond_emb(c)
        h = self.conv_in(x)
        skips = [h]
        for down in self.downs:
            h = down(h, emb)
        # collect every down level's recorded activations, in order
        for down in self.downs:
            skips.extend(down.saved)
        h = self.mid_block(h, emb)
        for up in self.ups:
            h = up(h, emb, skips)
        return self.conv_out(h)
import torch
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
# Class names, indexed by integer label id (position in the list is the id
# passed to the model's 10-entry class embedding). These appear to be the
# Fashion-MNIST classes.
LABELS = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]
def sigmas_karras(n, sigma_min=0.01, sigma_max=80., rho=7.):
    """Karras-style noise schedule of ``n`` levels, plus a final 0.

    Interpolates linearly between ``sigma_max**(1/rho)`` and
    ``sigma_min**(1/rho)`` and raises the result to ``rho``. The result
    therefore runs from ``sigma_max`` down to ``sigma_min``, with a trailing
    ``0.`` appended.

    Returns:
        1-D tensor of length ``n + 1``.
    """
    inv_rho = 1 / rho
    hi = sigma_max ** inv_rho
    lo = sigma_min ** inv_rho
    t = torch.linspace(0, 1, n)
    sigmas = (hi + t * (lo - hi)) ** rho
    return torch.cat((sigmas, torch.zeros(1)))
def scalings(sig, sig_data=0.66):
    """Karras/EDM preconditioning coefficients for noise level ``sig``.

    With ``v = sig**2 + sig_data**2``:
        c_skip = sig_data**2 / v
        c_out  = sig * sig_data / sqrt(v)
        c_in   = 1 / sqrt(v)

    ``sig`` must be a tensor (``.sqrt()`` is called on ``v``). ``sig_data``
    is presumably the std of the training data — verify against training.

    Returns:
        Tuple ``(c_skip, c_out, c_in)``.
    """
    variance = sig ** 2 + sig_data ** 2
    std = variance.sqrt()
    c_skip = sig_data ** 2 / variance
    c_out = sig * sig_data / std
    c_in = 1 / std
    return c_skip, c_out, c_in
def denoise(model, x, sig, label, device):
    """Preconditioned denoiser ``D(x) = c_skip*x + c_out*model(c_in*x, sig, label)``.

    Args:
        model: Callable taking a tuple ``(scaled_x, sig, labels)``.
        x: Noisy input.
        sig: Scalar (0-d) noise-level tensor; it gains a leading axis and is
            passed to the model unscaled.
        label: Integer class id; wrapped into a 1-element tensor.
        device: Device for ``sig`` and the label tensor.
    """
    sig = sig[None].to(device)
    c_skip, c_out, c_in = scalings(sig)
    labels = torch.tensor([label], device=device)
    prediction = model((x * c_in, sig, labels))
    return prediction * c_out + x * c_skip
@torch.no_grad()
def sample_euler(x, sigs, i, model, label, device):
    """One Euler step from ``sigs[i]`` to ``sigs[i + 1]``.

    Uses the slope ``(x - D(x)) / sigma``, where ``D`` is ``denoise``. Runs
    without autograd.
    """
    sig_now = sigs[i]
    sig_next = sigs[i + 1]
    x0_hat = denoise(model, x, sig_now, label, device)
    slope = (x - x0_hat) / sig_now
    return x + slope * (sig_next - sig_now)
def generate(class_name, model, steps=100, sigma_max=80., device=None):
    """Sample one 32x32 single-channel image of the requested class.

    Starts from Gaussian noise scaled by ``sigma_max`` and applies
    ``sample_euler`` along a ``sigmas_karras`` schedule.

    Args:
        class_name: A name from ``LABELS`` (matched case-insensitively,
            surrounding whitespace ignored) or an integer index into ``LABELS``.
        model: Denoising network; it is put in eval mode and moved to
            ``device`` in place.
        steps: Number of noise levels in the schedule; must be >= 1.
        sigma_max: Initial noise level.
        device: Torch device; defaults to ``'cuda'`` when available, else ``'cpu'``.

    Returns:
        CPU tensor of shape ``[1, 32, 32]``.

    Raises:
        ValueError: If ``steps < 1`` or the class name / index is unknown.
    """
    # With steps == 0 the schedule is just [0.], so no step would run and the
    # caller would silently get the raw starting noise back.
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    label_idx = _resolve_label(class_name)
    model.eval()
    model.to(device)
    # Noise is drawn with torch's default device/generator, then moved.
    x = torch.randn(1, 1, 32, 32).to(device) * sigma_max
    sigs = sigmas_karras(steps, sigma_max=sigma_max).to(device)
    for i in range(len(sigs) - 1):
        x = sample_euler(x, sigs, i, model, label_idx, device)
    return x.squeeze(0).cpu()


def _resolve_label(class_name):
    """Map a class name (or integer id) to its index in ``LABELS``; raise ValueError if unknown."""
    if isinstance(class_name, int) and not isinstance(class_name, bool):
        if 0 <= class_name < len(LABELS):
            return class_name
    else:
        # Non-string inputs still fail here with AttributeError, as before.
        wanted = class_name.strip().lower()
        for idx, label in enumerate(LABELS):
            if label.lower() == wanted:
                return idx
    available = "\n".join([f" {i}: {name}" for i, name in enumerate(LABELS)])
    raise ValueError(
        f"Invalid class name: '{class_name}'\n\n"
        f"Available classes:\n{available}"
    )
def show_image(image_tensor):
    """
    Display a generated image.

    Args:
        image_tensor: Tensor of shape [1, 32, 32] or [32, 32]. It may live on
            any device and may require grad; a detached CPU copy is plotted,
            so CUDA tensors no longer fail inside matplotlib.
    """
    img = image_tensor.detach().cpu()
    if img.dim() == 3:
        img = img.squeeze(0)
    plt.figure(figsize=(4, 4))
    # Fixed vmin/vmax: values outside [-1, 1] are clipped instead of the
    # grey scale being re-stretched per image.
    plt.imshow(img, cmap='gray', vmin=-1, vmax=1)
    plt.axis('off')
    plt.tight_layout()
    plt.show()