Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| """model.ipynb | |
| Automatically generated by Colab. | |
| Original file is located at | |
| https://colab.research.google.com/drive/1_9Ac8If0MzzNMrKBdneReuZPjByJNzHE | |
| """ | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| # -------------------------------------------------- | |
| # UNet for Text-to-Digit Diffusion (MNIST) | |
| # -------------------------------------------------- | |
class UNet(nn.Module):
    """Small class-conditional UNet for single-channel images (MNIST-sized).

    Architecture:
    - Encoder: three convolutions at full, 1/2 and 1/4 resolution
      (28x28 -> 14x14 -> 7x7 for a 28x28 input).
    - Bottleneck: a single convolution.
    - Decoder: two transposed convolutions with *additive* skip connections,
      followed by a 1-channel output convolution.

    The diffusion timestep and the class label are each embedded to
    ``emb_dim`` features, summed, projected to 128 channels and added
    (broadcast over H and W) to the bottleneck input.

    Whether the output is meant to be predicted noise or a clean image
    depends on the training loop, which is not defined here.

    Args:
        emb_dim: Width of the timestep and label embeddings.
        num_classes: Number of distinct labels accepted by ``forward``.
        timesteps: Total number of diffusion steps. ``forward`` divides ``t``
            by this value, so ``t`` in ``[0, timesteps]`` maps to ``[0, 1]``.

    Note:
        Submodule names, shapes and construction order define both the
        ``state_dict`` layout and (under a fixed seed) the initial weights.
        Changing them breaks loading of saved checkpoints.
    """

    def __init__(self, emb_dim: int = 64, num_classes: int = 10, timesteps: int = 100):
        super().__init__()
        # Time embedding: the scalar (normalized) timestep goes through a
        # 2-layer MLP. There is no sinusoidal encoding.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim)
        )
        # Label embedding: one learned vector per class.
        self.label_embed = nn.Embedding(num_classes, emb_dim)
        # Encoder: enc2/enc3 halve the spatial size (stride 2).
        self.enc1 = nn.Conv2d(1, 32, 3, padding=1)
        self.enc2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.enc3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        # Bottleneck
        self.bot = nn.Conv2d(128, 128, 3, padding=1)
        # Conditioning projection to the bottleneck channel width.
        self.cond_proj = nn.Linear(emb_dim, 128)
        # Decoder: kernel 4 / stride 2 / padding 1 exactly doubles H and W,
        # undoing enc3 and enc2 when the input size is divisible by 4.
        self.dec3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.dec2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.dec1 = nn.Conv2d(32, 1, 3, padding=1)
        self.timesteps = timesteps

    def forward(self, x, t, labels):
        """Run the network on a batch of images.

        Args:
            x: Image tensor, assumed to be (B, 1, H, W) with H and W divisible
                by 4. Otherwise the decoder outputs do not line up with the
                skip tensors and the addition fails.
            t: Timestep(s). Accepts a Python number, a 0-d tensor, a (B,)
                tensor or a (B, 1) tensor. A scalar is broadcast across the
                batch.
            labels: Integer class label(s) in ``[0, num_classes)``. Accepts a
                scalar, a sequence, a (B,) tensor or a (B, 1) tensor.

        Returns:
            Tensor with one channel and the same spatial size as ``x``.
        """
        # Accept plain Python values and tensors living on another device.
        t = torch.as_tensor(t, device=x.device)
        labels = torch.as_tensor(labels, device=x.device)
        # Collapse e.g. (B, 1) to (B,). Otherwise the embeddings come out as
        # (B, 1, emb_dim) and broadcast against (B, emb_dim) into
        # (B, B, emb_dim). 0-d and 1-d inputs are left untouched.
        if t.dim() > 1:
            t = t.reshape(-1)
        if labels.dim() > 1:
            labels = labels.reshape(-1)

        # Time embedding: add a feature axis for Linear(1, emb_dim) and
        # normalize by the number of diffusion steps.
        t_emb = self.time_mlp(t.unsqueeze(-1).float() / self.timesteps)
        # Label embedding
        l_emb = self.label_embed(labels)

        # Conditioning: sum the embeddings, project to 128 channels and add
        # two trailing singleton dims so it broadcasts over H and W.
        cond = self.cond_proj(t_emb + l_emb).unsqueeze(-1).unsqueeze(-1)

        # Encoder
        x1 = F.relu(self.enc1(x))
        x2 = F.relu(self.enc2(x1))
        x3 = F.relu(self.enc3(x2))

        # Bottleneck: conditioning is injected only here.
        h = F.relu(self.bot(x3 + cond))

        # Decoder: upsample, then add (not concatenate) the matching
        # encoder features as skip connections.
        h = F.relu(self.dec3(h)) + x2
        h = F.relu(self.dec2(h)) + x1
        return self.dec1(h)
| # -------------------------------------------------- | |
| # Factory function (EffNet-style) | |
| # -------------------------------------------------- | |
def create_diffusion_model(
    emb_dim: int = 64,
    num_classes: int = 10,
    timesteps: int = 100,
    seed: int = 42,
    device: str = "cpu"
):
    """Build a freshly initialized conditional diffusion UNet.

    The global PyTorch RNG is seeded with ``seed`` immediately before the
    network is constructed. Its initial weights are therefore reproducible
    across calls that use the same arguments.

    Args:
        emb_dim: Width of the timestep/label embeddings forwarded to ``UNet``.
        num_classes: Number of class labels the model accepts.
        timesteps: Number of diffusion steps ``UNet`` uses to normalize ``t``.
        seed: Value passed to ``torch.manual_seed`` before construction.
        device: Target device for the parameters (e.g. ``"cpu"``).

    Returns:
        nn.Module: the ``UNet`` instance, already moved to ``device``.
    """
    unet_kwargs = dict(emb_dim=emb_dim, num_classes=num_classes, timesteps=timesteps)

    # NOTE: torch.manual_seed reseeds the *global* generator. Random ops the
    # caller performs after this function returns are affected as well.
    torch.manual_seed(seed)

    net = UNet(**unet_kwargs)
    return net.to(device)