Spaces:
Running
Running
File size: 2,691 Bytes
88ba839 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 | # -*- coding: utf-8 -*-
"""model.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1_9Ac8If0MzzNMrKBdneReuZPjByJNzHE
"""
import torch
from torch import nn
import torch.nn.functional as F
# --------------------------------------------------
# UNet for Text-to-Digit Diffusion (MNIST)
# --------------------------------------------------
class UNet(nn.Module):
def __init__(self, emb_dim: int = 64, num_classes: int = 10, timesteps: int = 100):
super().__init__()
# Time embedding
self.time_mlp = nn.Sequential(
nn.Linear(1, emb_dim),
nn.ReLU(),
nn.Linear(emb_dim, emb_dim)
)
# Label embedding
self.label_embed = nn.Embedding(num_classes, emb_dim)
# Encoder
self.enc1 = nn.Conv2d(1, 32, 3, padding=1)
self.enc2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
self.enc3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
# Bottleneck
self.bot = nn.Conv2d(128, 128, 3, padding=1)
# Conditioning projection
self.cond_proj = nn.Linear(emb_dim, 128)
# Decoder
self.dec3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
self.dec2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
self.dec1 = nn.Conv2d(32, 1, 3, padding=1)
self.timesteps = timesteps
def forward(self, x, t, labels):
# Time embedding
t = t.unsqueeze(-1).float() / self.timesteps
t_emb = self.time_mlp(t)
# Label embedding
l_emb = self.label_embed(labels)
# Conditioning
cond = t_emb + l_emb
cond = self.cond_proj(cond).unsqueeze(-1).unsqueeze(-1)
# Encoder
x1 = F.relu(self.enc1(x))
x2 = F.relu(self.enc2(x1))
x3 = F.relu(self.enc3(x2))
# Bottleneck + conditioning
h = F.relu(self.bot(x3 + cond))
# Decoder with skip connections
h = F.relu(self.dec3(h)) + x2
h = F.relu(self.dec2(h)) + x1
return self.dec1(h)
# --------------------------------------------------
# Factory function (EffNet-style)
# --------------------------------------------------
def create_diffusion_model(
    emb_dim: int = 64,
    num_classes: int = 10,
    timesteps: int = 100,
    seed: int = 42,
    device: str = "cpu",
):
    """Build a seeded, timestep- and label-conditioned diffusion UNet.

    Args:
        emb_dim: Width of the time and label embeddings inside the UNet.
        num_classes: Number of class labels the model can be conditioned on.
        timesteps: Divisor the UNet applies to timestep indices.
        seed: Passed to ``torch.manual_seed`` before the model is built, so
            parameter initialisation is reproducible. This seeds the global
            torch RNG, so it also affects later random operations.
        device: Target device the freshly built model is moved to.

    Returns:
        nn.Module: the ``UNet`` instance, already on ``device``.
    """
    # Seed first: parameter initialisation happens inside the constructor.
    torch.manual_seed(seed)
    net = UNet(emb_dim=emb_dim, num_classes=num_classes, timesteps=timesteps)
    net = net.to(device)
    return net
|