File size: 2,691 Bytes
88ba839
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: utf-8 -*-
"""model.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1_9Ac8If0MzzNMrKBdneReuZPjByJNzHE
"""

import torch
from torch import nn
import torch.nn.functional as F


# --------------------------------------------------
# UNet for Text-to-Digit Diffusion (MNIST)
# --------------------------------------------------
class UNet(nn.Module):
    """Small conditional UNet for text-to-digit diffusion (MNIST).

    Layout: a three-convolution encoder (the 2nd and 3rd convolutions use
    stride 2), one bottleneck convolution that receives the conditioning
    signal, and a decoder of two stride-2 transposed convolutions followed
    by a final convolution. Decoder activations are added to the matching
    encoder activations (additive skip connections).

    The conditioning signal is the sum of an MLP embedding of the normalised
    timestep and a learned label embedding, projected to the bottleneck's
    128 channels and broadcast over the spatial dimensions.

    Args:
        emb_dim: Width of the time and label embeddings.
        num_classes: Number of rows in the label embedding table.
        timesteps: Divisor used to normalise ``t`` in :meth:`forward`.
            Must be positive.

    Raises:
        ValueError: If ``timesteps`` is not positive.
    """

    def __init__(self, emb_dim: int = 64, num_classes: int = 10, timesteps: int = 100):
        super().__init__()

        # A zero divisor would silently turn every time input into inf/NaN,
        # and a negative one would flip its sign; fail early instead.
        if timesteps <= 0:
            raise ValueError(f"timesteps must be positive, got {timesteps}")

        # NOTE: submodule names and construction order are kept stable so
        # that state_dict keys and seeded initialisation do not change.

        # Time embedding: scalar normalised timestep -> emb_dim vector
        self.time_mlp = nn.Sequential(
            nn.Linear(1, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim)
        )

        # Label embedding: class index -> emb_dim vector
        self.label_embed = nn.Embedding(num_classes, emb_dim)

        # Encoder (enc2 and enc3 each halve the spatial size, rounding up)
        self.enc1 = nn.Conv2d(1, 32, 3, padding=1)
        self.enc2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.enc3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)

        # Bottleneck
        self.bot = nn.Conv2d(128, 128, 3, padding=1)

        # Conditioning projection: emb_dim -> bottleneck channel count
        self.cond_proj = nn.Linear(emb_dim, 128)

        # Decoder (dec3 and dec2 each double the spatial size exactly)
        self.dec3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.dec2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.dec1 = nn.Conv2d(32, 1, 3, padding=1)

        self.timesteps = timesteps

    @staticmethod
    def _crop_like(h: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Trim trailing rows/columns of ``h`` to the spatial size of ``ref``."""
        return h[..., : ref.shape[-2], : ref.shape[-1]]

    def forward(self, x: torch.Tensor, t: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Run the conditional UNet.

        Args:
            x: Image batch of shape ``(B, 1, H, W)``. ``H`` and ``W`` may be
                any size; they do not have to be multiples of 4.
            t: Timestep for each sample, given as a 0-d tensor, ``(B,)`` or
                ``(B, 1)``. It is divided by ``self.timesteps`` before being
                embedded.
            labels: Integer class indices of shape ``(B,)``, used to index
                the label embedding table.

        Returns:
            Tensor with the same shape as ``x``.
            NOTE(review): presumably the predicted noise for the diffusion
            objective - confirm against the training loop.
        """
        # Time embedding. reshape(-1, 1) accepts (), (B,) and (B, 1) alike;
        # a plain unsqueeze(-1) on a (B, 1) input would add a third axis and
        # broadcast wrongly against the (B, emb_dim) label embedding.
        t = t.reshape(-1, 1).float() / self.timesteps
        t_emb = self.time_mlp(t)

        # Label embedding
        l_emb = self.label_embed(labels)

        # Conditioning: (B, 128, 1, 1) so it broadcasts over H and W
        cond = t_emb + l_emb
        cond = self.cond_proj(cond).unsqueeze(-1).unsqueeze(-1)

        # Encoder
        x1 = F.relu(self.enc1(x))
        x2 = F.relu(self.enc2(x1))
        x3 = F.relu(self.enc3(x2))

        # Bottleneck + conditioning
        h = F.relu(self.bot(x3 + cond))

        # Decoder with skip connections. The stride-2 convs produce
        # ceil(n / 2) while the transposed convs produce exactly 2 * n, so
        # for odd sizes the upsampled map is one row/column too large; crop
        # it to the skip tensor. This is a no-op when sizes already match.
        h = self._crop_like(F.relu(self.dec3(h)), x2) + x2
        h = self._crop_like(F.relu(self.dec2(h)), x1) + x1

        return self.dec1(h)


# --------------------------------------------------
# Factory function (EffNet-style)
# --------------------------------------------------
def create_diffusion_model(
    emb_dim: int = 64,
    num_classes: int = 10,
    timesteps: int = 100,
    seed: int = 42,
    device: str = "cpu"
):
    """
    Build a conditional diffusion UNet with seeded weight initialisation.

    Args:
        emb_dim: forwarded to ``UNet`` as ``emb_dim``
        num_classes: forwarded to ``UNet`` as ``num_classes``
        timesteps: forwarded to ``UNet`` as ``timesteps``
        seed: value passed to ``torch.manual_seed`` before the model is built
        device: target passed to ``.to(...)`` on the freshly built model

    Returns:
        model (nn.Module): diffusion UNet moved to ``device``
    """

    # Seed torch's global RNG first so the layer initialisation that follows
    # is reproducible (note: this affects the caller's RNG state too).
    torch.manual_seed(seed)

    unet_kwargs = {
        "emb_dim": emb_dim,
        "num_classes": num_classes,
        "timesteps": timesteps,
    }
    return UNet(**unet_kwargs).to(device)