diffusion-model / model.py
Dulamaa's picture
initial commit
88ba839
Raw
History Blame Contribute Delete
2.69 kB
# -*- coding: utf-8 -*-
"""model.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1_9Ac8If0MzzNMrKBdneReuZPjByJNzHE
"""
import torch
from torch import nn
import torch.nn.functional as F
# --------------------------------------------------
# UNet for Text-to-Digit Diffusion (MNIST)
# --------------------------------------------------
class UNet(nn.Module):
    """Small conditional UNet for text-to-digit diffusion on MNIST-sized images.

    The network takes a single-channel image batch, a diffusion timestep and a
    class label, and returns a single-channel map. A joint time + label
    embedding is projected to 128 channels and added to the feature map just
    before the bottleneck convolution.

    The two stride-2 encoder convolutions are mirrored by two stride-2
    transposed convolutions, and the decoder adds the encoder activations back
    as skip connections. Those additions only line up when the input height
    and width are multiples of 4 (e.g. 28x28).

    Args:
        emb_dim: Width of the time and label embeddings.
        num_classes: Number of distinct labels accepted by the label embedding.
        timesteps: Divisor used to normalise ``t`` in ``forward``. Must be
            positive.

    Raises:
        ValueError: If ``timesteps`` is not positive.
    """

    def __init__(self, emb_dim: int = 64, num_classes: int = 10, timesteps: int = 100):
        # ``forward`` divides by ``timesteps``; zero would silently turn the
        # conditioning into inf/NaN, so reject it up front.
        if timesteps <= 0:
            raise ValueError(f"timesteps must be a positive integer, got {timesteps}")
        super().__init__()

        # NOTE: attribute names and construction order are kept stable so that
        # existing checkpoints (state_dict keys) and seeded initialisation
        # continue to match.

        # Time embedding: scalar (normalised) timestep -> emb_dim vector
        self.time_mlp = nn.Sequential(
            nn.Linear(1, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim)
        )
        # Label embedding: class index -> emb_dim vector
        self.label_embed = nn.Embedding(num_classes, emb_dim)

        # Encoder: 1 -> 32 -> 64 -> 128 channels, halving H and W twice
        self.enc1 = nn.Conv2d(1, 32, 3, padding=1)
        self.enc2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.enc3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)

        # Bottleneck
        self.bot = nn.Conv2d(128, 128, 3, padding=1)

        # Conditioning projection: emb_dim -> bottleneck channel count
        self.cond_proj = nn.Linear(emb_dim, 128)

        # Decoder: 128 -> 64 -> 32 -> 1 channels, doubling H and W twice
        self.dec3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.dec2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.dec1 = nn.Conv2d(32, 1, 3, padding=1)

        self.timesteps = timesteps

    def forward(self, x, t, labels):
        """Run the conditional UNet.

        Args:
            x: Image batch of shape ``(B, 1, H, W)`` with ``H`` and ``W``
                multiples of 4.
            t: Diffusion timestep(s). May be a Python number, a 0-dim tensor,
                or a tensor of shape ``(B,)`` or ``(B, 1)``. A single value is
                broadcast over the batch.
            labels: Integer class indices of shape ``(B,)``, each in
                ``[0, num_classes)``.

        Returns:
            Tensor of shape ``(B, 1, H, W)``.
        """
        # Time embedding. ``as_tensor`` lets callers pass a plain int/float,
        # and ``reshape(-1, 1)`` yields one scalar feature per sample whether
        # ``t`` arrives as a scalar, ``(B,)`` or ``(B, 1)``.
        t = torch.as_tensor(t, device=x.device).float().reshape(-1, 1)
        t = t / self.timesteps
        t_emb = self.time_mlp(t)

        # Label embedding
        l_emb = self.label_embed(labels)

        # Conditioning: sum the embeddings, project to the bottleneck width and
        # add two trailing singleton dims so it broadcasts over H and W.
        cond = t_emb + l_emb
        cond = self.cond_proj(cond).unsqueeze(-1).unsqueeze(-1)

        # Encoder
        x1 = F.relu(self.enc1(x))
        x2 = F.relu(self.enc2(x1))
        x3 = F.relu(self.enc3(x2))

        # Bottleneck + conditioning
        h = F.relu(self.bot(x3 + cond))

        # Decoder with additive skip connections from the matching encoder stage
        h = F.relu(self.dec3(h)) + x2
        h = F.relu(self.dec2(h)) + x1
        return self.dec1(h)
# --------------------------------------------------
# Factory function (EffNet-style)
# --------------------------------------------------
def create_diffusion_model(
    emb_dim: int = 64,
    num_classes: int = 10,
    timesteps: int = 100,
    seed: int = 42,
    device: str = "cpu"
):
    """Build a seeded conditional diffusion UNet and place it on ``device``.

    Args:
        emb_dim: Embedding width forwarded to ``UNet``.
        num_classes: Number of label classes forwarded to ``UNet``.
        timesteps: Timestep count forwarded to ``UNet``.
        seed: Value passed to ``torch.manual_seed`` before the model is built.
        device: Target device handed to ``Module.to``.

    Returns:
        model (nn.Module): the ``UNet`` instance, moved to ``device``.
    """
    # Seed the global torch RNG first so the weight initialisation that
    # follows is repeatable for a given ``seed``.
    torch.manual_seed(seed)

    unet_kwargs = {
        "emb_dim": emb_dim,
        "num_classes": num_classes,
        "timesteps": timesteps,
    }
    net = UNet(**unet_kwargs)
    return net.to(device)