import math
from functools import wraps

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from huggingface_hub import PyTorchModelHubMixin
from torch import nn


def timestep_embedding(tsteps, emb_dim, max_period=10000):
    """Return sinusoidal embeddings for a 1-D tensor of timesteps / noise levels.

    Args:
        tsteps: 1-D tensor (it is indexed as ``tsteps[:, None]``).
        emb_dim: Embedding width. An odd width gets one zero column appended.
        max_period: Sets the lowest frequency, ``1 / max_period``.

    Returns:
        Float tensor of shape ``(len(tsteps), emb_dim)``: the sine half followed by
        the cosine half.
    """
    # Frequencies decay geometrically from 1 down to 1 / max_period.
    exponent = -math.log(max_period) * torch.linspace(0, 1, emb_dim // 2, device=tsteps.device)
    emb = tsteps[:, None].float() * exponent.exp()[None, :]
    emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
    # sin/cos only fill 2 * (emb_dim // 2) columns; pad one column for odd widths.
    return F.pad(emb, (0, 1, 0, 0)) if emb_dim % 2 == 1 else emb


def lin(ni, nf, act=nn.SiLU, norm=None, bias=True):
    """Build a pre-activation linear layer.

    The result is an ``nn.Sequential`` of an optional ``norm(ni)``, then an
    optional ``act()``, then ``nn.Linear(ni, nf, bias=bias)``.
    """
    layers = nn.Sequential()
    if norm:
        layers.append(norm(ni))
    if act:
        layers.append(act())
    layers.append(nn.Linear(ni, nf, bias=bias))
    return layers


def unet_conv(ni, nf, act=nn.SiLU, norm=None, bias=True, ks=3, stride=1):
    """Build a pre-activation 2-D conv layer.

    The result is an ``nn.Sequential`` of an optional ``norm(ni)``, then an
    optional ``act()``, then a ``ks x ks`` ``nn.Conv2d`` with ``ks // 2`` padding.
    With an odd ``ks`` and ``stride=1`` the padding preserves spatial size.
    """
    layers = nn.Sequential()
    if norm:
        layers.append(norm(ni))
    if act:
        layers.append(act())
    layers.append(nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks // 2, bias=bias))
    return layers


class EmbResBlock(nn.Module):
    """Residual block of two pre-activation convs, modulated by an embedding.

    The embedding ``t`` goes through SiLU and a linear layer to give a
    per-channel ``scale`` and ``shift``. These are applied after the first conv
    as ``x * (1 + scale) + shift``. The input is added back at the end, through
    a 1x1 conv when the channel count changes.

    Args:
        n_emb: Width of the embedding ``t``.
        ni: Input channels.
        nf: Output channels; ``None`` means "same as ``ni``".
        act, norm, ks: Forwarded to both ``unet_conv`` layers.
        bias: Accepted for signature compatibility; not used by this block.
    """

    def __init__(self, n_emb, ni, nf=None, act=nn.SiLU, norm=nn.BatchNorm2d, bias=True, ks=3):
        super().__init__()
        if nf is None:
            nf = ni
        self.emb_proj = nn.Linear(n_emb, nf * 2)
        self.conv1 = unet_conv(ni, nf, norm=norm, act=act, ks=ks)
        self.conv2 = unet_conv(nf, nf, norm=norm, act=act, ks=ks)
        self.idconv = nn.Identity() if ni == nf else nn.Conv2d(ni, nf, 1)

    def forward(self, x, t):
        inp = x
        # Add two trailing singleton dims so scale/shift broadcast over H and W.
        emb = self.emb_proj(F.silu(t))[:, :, None, None]
        x = self.conv1(x)
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = x * (1 + scale) + shift
        x = self.conv2(x)
        return x + self.idconv(inp)


def saved(m, blk):
    """Patch ``m.forward`` in place so each output is also appended to ``blk.saved``.

    Returns ``m``. ``blk.saved`` must exist (as a list) by the time ``m`` runs.

    NOTE: the patched forward is a closure over this particular ``m`` and ``blk``.
    ``copy.deepcopy`` shares function objects instead of copying them, so a deep
    copy of the owning model would still run the ORIGINAL module's weights and
    append to the ORIGINAL ``blk``. ``DownBlock`` below therefore records its
    skip outputs explicitly. This helper is kept only for backward compatibility.
    """
    m_ = m.forward

    @wraps(m.forward)
    def _f(*args, **kwargs):
        res = m_(*args, **kwargs)
        blk.saved.append(res)
        return res

    m.forward = _f
    return m


class DownBlock(nn.Module):
    """Encoder stage: ``num_layers`` ``EmbResBlock``s, then an optional stride-2 conv.

    After ``forward``, ``self.saved`` holds, in order, the output of every
    resnet plus the downsampled output (only when ``add_down``). The U-Net
    uses these as skip connections.
    """

    def __init__(self, n_emb, ni, nf, add_down, num_layers=1):
        super().__init__()
        # Submodule names and creation order match the original layout, so
        # state-dict keys and parameter init are unchanged.
        self.resnets = nn.ModuleList(
            [EmbResBlock(n_emb, ni if i == 0 else nf, nf) for i in range(num_layers)]
        )
        if add_down:
            self.down = nn.Conv2d(nf, nf, 3, stride=2, padding=1)
        else:
            self.down = nn.Identity()
        self.add_down = add_down
        self.saved = []

    def forward(self, x, t):
        # Collect skips explicitly instead of via monkeypatched forwards, so
        # deep copies (e.g. EMA models) keep working on their own weights.
        self.saved = []
        for resnet in self.resnets:
            x = resnet(x, t)
            self.saved.append(x)
        x = self.down(x)
        if self.add_down:
            self.saved.append(x)
        return x


def upsample(nf):
    """2x upsample (``nn.Upsample``'s default nearest mode) followed by a 3x3 conv keeping ``nf`` channels."""
    return nn.Sequential(nn.Upsample(scale_factor=2.), nn.Conv2d(nf, nf, 3, padding=1))


class UpBlock(nn.Module):
    """Decoder stage.

    It has ``num_layers`` ``EmbResBlock``s, each fed ``cat([x, skip])``, followed
    by an optional ``upsample``.

    Args:
        n_emb: Embedding width.
        ni: Channels of the skip consumed by the LAST resnet.
        prev_nf: Channels of the incoming ``x``.
        nf: Output channels (and channels of the earlier skips).
        add_up: Whether to upsample at the end.
        num_layers: Number of resnets, i.e. number of skips consumed.
    """

    def __init__(self, n_emb, ni, prev_nf, nf, add_up, num_layers=1):
        super().__init__()
        self.resnets = nn.ModuleList([
            EmbResBlock(
                n_emb,
                (prev_nf if i == 0 else nf) + (ni if (i == num_layers - 1) else nf),
                nf,
            )
            for i in range(num_layers)
        ])
        self.up = upsample(nf) if add_up else nn.Identity()

    def forward(self, x, t, ups):
        # NOTE: pops from the caller's list; skips are consumed last-in, first-out.
        for resnet in self.resnets:
            x = resnet(torch.cat([x, ups.pop()], dim=1), t)
        x = self.up(x)
        return x


class EmbUnetModel(nn.Module, PyTorchModelHubMixin):
    """Class-conditioned U-Net.

    ``forward`` takes a tuple ``(x, t, c)``. ``t`` is embedded with
    ``timestep_embedding`` and an MLP. ``c`` is looked up in
    ``nn.Embedding(10, ...)``, so labels must lie in ``[0, 10)``. The two
    embeddings are summed and fed to every residual block.

    Args:
        in_channels, out_channels: Image channels in and out.
        nfs: Channel width of each resolution level, from highest resolution to lowest.
        n_layers: Resnets per down block. Each up block has ``n_layers + 1``
            resnets so that it consumes one extra skip: the down block's
            downsample output, or ``conv_in``'s output for the top level.
    """

    def __init__(self, in_channels=1, out_channels=1, nfs=(32, 64, 128, 256), n_layers=2):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, nfs[0], kernel_size=3, padding=1)
        self.n_temb = nf = nfs[0]
        n_emb = nf * 4
        self.cond_emb = nn.Embedding(10, n_emb)
        self.mlp_emb = nn.Sequential(
            lin(self.n_temb, n_emb, norm=nn.BatchNorm1d),
            lin(n_emb, n_emb),
        )
        self.downs = nn.ModuleList()
        for i in range(len(nfs)):
            ni = nf
            nf = nfs[i]
            self.downs.append(
                DownBlock(n_emb, ni, nf, add_down=i != len(nfs) - 1, num_layers=n_layers)
            )
        self.mid_block = EmbResBlock(n_emb, ni=nfs[-1], nf=None)
        rev_nfs = list(reversed(nfs))
        nf = rev_nfs[0]
        self.ups = nn.ModuleList()
        for i in range(len(rev_nfs)):
            prev_nf = nf
            nf = rev_nfs[i]
            # Channels of the extra skip consumed by this level's last resnet.
            ni = rev_nfs[min(i + 1, len(nfs) - 1)]
            self.ups.append(
                UpBlock(n_emb, ni, prev_nf, nf, add_up=i != len(nfs) - 1, num_layers=n_layers + 1)
            )
        self.conv_out = unet_conv(
            ni=nfs[0], nf=out_channels, norm=nn.BatchNorm2d, act=nn.SiLU, bias=False
        )

    def forward(self, inp):
        x, t, c = inp
        temb = timestep_embedding(t, self.n_temb)
        emb = self.mlp_emb(temb) + self.cond_emb(c)
        x = self.conv_in(x)
        saved = [x]
        for block in self.downs:
            x = block(x, emb)
        # Skips in encoder order; the up blocks pop them from the end.
        saved += [p for o in self.downs for p in o.saved]
        x = self.mid_block(x, emb)
        for block in self.ups:
            x = block(x, emb, saved)
        return self.conv_out(x)
# Class names for conditioning; a name's position in this list is the integer
# label passed to the model (these match the Fashion-MNIST classes).
LABELS = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot'
]


def sigmas_karras(n, sigma_min=0.01, sigma_max=80., rho=7.):
    """Return the Karras et al. noise schedule with a trailing zero.

    ``n`` values are spaced uniformly in ``sigma**(1/rho)``, running from
    ``sigma_max`` down to ``sigma_min``. A final ``0.`` is appended, so the
    result has length ``n + 1``.
    """
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min**(1 / rho)
    max_inv_rho = sigma_max**(1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
    return torch.cat([sigmas, torch.tensor([0.])])


def scalings(sig, sig_data=0.66):
    """Return ``(c_skip, c_out, c_in)`` preconditioning factors for noise level ``sig``.

    With ``v = sig**2 + sig_data**2``: ``c_skip = sig_data**2 / v``,
    ``c_out = sig * sig_data / sqrt(v)`` and ``c_in = 1 / sqrt(v)``.
    """
    totvar = sig**2 + sig_data**2
    c_skip = sig_data**2 / totvar
    c_out = sig * sig_data / totvar.sqrt()
    c_in = 1 / totvar.sqrt()
    return c_skip, c_out, c_in


def denoise(model, x, sig, label, device):
    """Return the preconditioned denoised estimate of ``x`` at noise level ``sig``.

    The model is called as ``model((x * c_in, sig, labels))``. Here ``sig`` is
    reshaped to a 1-element tensor and ``labels`` is ``tensor([label])``. Its
    output is combined as ``out * c_out + x * c_skip``.
    """
    sig = sig[None].to(device)
    c_skip, c_out, c_in = scalings(sig)
    return model((x * c_in, sig, torch.tensor([label], device=device))) * c_out + x * c_skip


@torch.no_grad()
def sample_euler(x, sigs, i, model, label, device):
    """Take one Euler step of the probability-flow ODE from ``sigs[i]`` to ``sigs[i+1]``."""
    sig, sig2 = sigs[i], sigs[i + 1]
    denoised = denoise(model, x, sig, label, device)
    # Slope d = (x - denoised) / sig, integrated over the step (sig2 - sig).
    return x + (x - denoised) / sig * (sig2 - sig)


def generate(class_name, model, steps=100, sigma_max=80., device=None, sigma_min=0.01):
    """Sample one 32x32 single-channel image of ``class_name`` with Euler steps.

    Args:
        class_name: One of ``LABELS``, matched case-insensitively after stripping whitespace.
        model: Callable on ``(x, sigma, labels)``. It is switched to eval mode and moved to ``device``.
        steps: Number of Euler steps (must be >= 1).
        sigma_max: Starting noise level; the initial noise is ``randn * sigma_max``.
        device: Defaults to ``'cuda'`` when available, else ``'cpu'``.
        sigma_min: Smallest non-zero noise level in the schedule.

    Returns:
        The final sample with the batch dim removed, on the CPU. This has shape
        ``(1, 32, 32)`` assuming the model returns a tensor shaped like its input.

    Raises:
        ValueError: If ``class_name`` is not in ``LABELS`` or ``steps < 1``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    class_name_normalized = class_name.strip().lower()
    # First match wins, mirroring a linear scan over LABELS.
    label_idx = next(
        (idx for idx, label in enumerate(LABELS) if label.lower() == class_name_normalized),
        None,
    )
    if label_idx is None:
        available = "\n".join([f" {i}: {name}" for i, name in enumerate(LABELS)])
        raise ValueError(
            f"Invalid class name: '{class_name}'\n\n"
            f"Available classes:\n{available}"
        )
    # With steps < 1 the schedule is just [0.] and no step would run, silently
    # returning pure scaled noise.
    if steps < 1:
        raise ValueError(f"steps must be at least 1, got {steps}")

    model.eval()
    model.to(device)
    # Draw on the CPU and then move, so seeded runs reproduce the original stream.
    x = torch.randn(1, 1, 32, 32).to(device) * sigma_max
    sigs = sigmas_karras(steps, sigma_min=sigma_min, sigma_max=sigma_max).to(device)
    for i in range(len(sigs) - 1):
        x = sample_euler(x, sigs, i, model, label_idx, device)
    return x.squeeze(0).cpu()


def show_image(image_tensor):
    """
    Display a generated image.

    Args:
        image_tensor: Tensor of shape [1, 32, 32] or [32, 32]

    Pixel values are mapped with vmin=-1 / vmax=1, so values outside that range
    saturate to black or white. Tensors on any device, or tracking gradients,
    are detached and copied to the CPU before plotting.
    """
    if image_tensor.dim() == 3:
        image_tensor = image_tensor.squeeze(0)
    image = image_tensor.detach().cpu()
    plt.figure(figsize=(4, 4))
    plt.imshow(image, cmap='gray', vmin=-1, vmax=1)
    plt.axis('off')
    plt.tight_layout()
    plt.show()