# -*- coding: utf-8 -*-
"""model.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1_9Ac8If0MzzNMrKBdneReuZPjByJNzHE
"""

import torch
from torch import nn
import torch.nn.functional as F


# --------------------------------------------------
# UNet for Text-to-Digit Diffusion (MNIST)
# --------------------------------------------------
class UNet(nn.Module):
    """Small conditional UNet that maps a 1-channel image to a 1-channel image.

    The network is conditioned on a timestep and a class label. Both are
    embedded to ``emb_dim`` features, summed, projected to 128 channels and
    added to the deepest encoder feature map before the bottleneck conv.

    Args:
        emb_dim: Width of the time and label embeddings.
        num_classes: Number of rows in the label embedding table.
        timesteps: Divisor used to scale the raw timestep before embedding.
    """

    def __init__(self, emb_dim: int = 64, num_classes: int = 10, timesteps: int = 100):
        super().__init__()

        # NOTE: layer names and construction order are kept stable so that
        # state_dict keys and seeded initialisation do not change.

        # Time embedding
        self.time_mlp = nn.Sequential(
            nn.Linear(1, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )

        # Label embedding
        self.label_embed = nn.Embedding(num_classes, emb_dim)

        # Encoder: two stride-2 convs, each halving H and W (rounding up)
        self.enc1 = nn.Conv2d(1, 32, 3, padding=1)
        self.enc2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.enc3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)

        # Bottleneck
        self.bot = nn.Conv2d(128, 128, 3, padding=1)

        # Conditioning projection (emb_dim -> bottleneck channels)
        self.cond_proj = nn.Linear(emb_dim, 128)

        # Decoder: each transposed conv exactly doubles H and W
        self.dec3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.dec2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.dec1 = nn.Conv2d(32, 1, 3, padding=1)

        self.timesteps = timesteps

    @staticmethod
    def _crop_like(h, ref):
        """Crop the last two dims of ``h`` to the spatial size of ``ref``."""
        return h[..., : ref.shape[-2], : ref.shape[-1]]

    def forward(self, x, t, labels):
        """Run the network on a batch.

        Args:
            x: Image tensor of shape ``(B, 1, H, W)``. H and W no longer need
                to be multiples of 4.
            t: Timestep tensor with one value per sample; ``(B,)``,
                ``(B, 1)`` and 0-dim tensors are all accepted.
            labels: Integer class indices for ``label_embed``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Time embedding. reshape(-1, 1) gives one scalar feature per sample
        # whatever the incoming layout; unsqueeze(-1) on a (B, 1) tensor would
        # add an extra axis and break the broadcast against the label
        # embedding.
        t = t.reshape(-1, 1).float() / self.timesteps
        t_emb = self.time_mlp(t)

        # Label embedding
        l_emb = self.label_embed(labels)

        # Conditioning, shaped to broadcast over the spatial dims
        cond = t_emb + l_emb
        cond = self.cond_proj(cond).unsqueeze(-1).unsqueeze(-1)

        # Encoder
        x1 = F.relu(self.enc1(x))
        x2 = F.relu(self.enc2(x1))
        x3 = F.relu(self.enc3(x2))

        # Bottleneck + conditioning
        h = F.relu(self.bot(x3 + cond))

        # Decoder with skip connections. The stride-2 encoders round odd
        # sizes up, so the doubled decoder output can be one pixel larger
        # than the skip tensor; crop before adding. For sizes divisible by 4
        # (e.g. 28x28) the crop is a no-op.
        h = self._crop_like(F.relu(self.dec3(h)), x2) + x2
        h = self._crop_like(F.relu(self.dec2(h)), x1) + x1

        return self.dec1(h)


# --------------------------------------------------
# Factory function (EffNet-style)
# --------------------------------------------------
def create_diffusion_model(
    emb_dim: int = 64,
    num_classes: int = 10,
    timesteps: int = 100,
    seed: int = 42,
    device: str = "cpu",
) -> UNet:
    """Create a conditional diffusion UNet model.

    Seeds torch's global RNG with ``seed`` before building the model, so the
    initial weights are reproducible. Note that this reseeds the global
    generator for the caller as well.

    Args:
        emb_dim: Width of the time and label embeddings.
        num_classes: Number of label classes.
        timesteps: Divisor used to scale timesteps inside the model.
        seed: Value passed to ``torch.manual_seed``.
        device: Device the model is moved to.

    Returns:
        model (nn.Module): diffusion UNet
    """
    # Reproducibility
    torch.manual_seed(seed)

    model = UNet(
        emb_dim=emb_dim,
        num_classes=num_classes,
        timesteps=timesteps,
    ).to(device)

    return model