RinKana commited on
Commit
884fd0e
·
verified ·
1 Parent(s): 1de5050

Upload diffusion_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diffusion_model.py +198 -0
diffusion_model.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ class SinusoidalPosEmb(nn.Module):
6
+ def __init__(self, dim):
7
+ super().__init__()
8
+ self.dim = dim
9
+
10
+ def forward(self, time):
11
+ device = time.device
12
+ half_dim = self.dim // 2
13
+ emb = math.log(10000) / (half_dim - 1)
14
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
15
+ emb = time[:, None] * emb[None, :]
16
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
17
+ return emb
18
+
19
+ class ResidualBlock(nn.Module):
20
+ def __init__(self, in_channels, out_channels, time_emb_dim):
21
+ super().__init__()
22
+
23
+ self.time_mlp = nn.Sequential(
24
+ nn.SiLU(),
25
+ nn.Linear(time_emb_dim, out_channels * 2)
26
+ )
27
+
28
+ self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
29
+ self.norm1 = nn.GroupNorm(32, out_channels)
30
+
31
+ self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
32
+ self.norm2 = nn.GroupNorm(32, out_channels)
33
+
34
+ self.residual_conv = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
35
+
36
+ self.act = nn.SiLU()
37
+
38
+ def forward(self, x, time_emb):
39
+ h = self.conv1(x)
40
+ h = self.norm1(h)
41
+ h = self.act(h)
42
+
43
+ # Add time embedding
44
+ time_emb = self.time_mlp(time_emb)
45
+ time_emb = time_emb[:, :, None, None]
46
+ scale, shift = time_emb.chunk(2, dim=1)
47
+ h = h * (scale + 1) + shift
48
+
49
+ h = self.conv2(h)
50
+ h = self.norm2(h)
51
+ h = self.act(h)
52
+
53
+ return h + self.residual_conv(x)
54
+
55
+ class SelfAttention(nn.Module):
56
+ def __init__(self, channels):
57
+ super().__init__()
58
+ self.norm = nn.GroupNorm(32, channels)
59
+ self.qkv = nn.Conv2d(channels, channels * 3, 1)
60
+ self.out = nn.Conv2d(channels, channels, 1)
61
+ self.scale = 1.0 / math.sqrt(channels)
62
+
63
+ def forward(self, x):
64
+ b, c, h, w = x.shape
65
+ h_norm = self.norm(x)
66
+ qkv = self.qkv(h_norm)
67
+ q, k, v = qkv.chunk(3, dim=1)
68
+
69
+ q = q.reshape(b, c, h * w).transpose(-2, -1)
70
+ k = k.reshape(b, c, h * w)
71
+ v = v.reshape(b, c, h * w).transpose(-2, -1)
72
+
73
+ attn = torch.softmax(q @ k * self.scale, dim=-1)
74
+ out = attn @ v
75
+ out = out.transpose(-2, -1).reshape(b, c, h, w)
76
+
77
+ return x + self.out(out)
78
+
79
+ class UNet(nn.Module):
80
+ def __init__(self, img_size=64, in_channels=3, out_channels=3, base_channels=128, ch_mult=(1, 2, 4)):
81
+ super().__init__()
82
+
83
+ self.time_embed = SinusoidalPosEmb(base_channels)
84
+ self.time_mlp = nn.Sequential(
85
+ nn.Linear(base_channels, base_channels * 4),
86
+ nn.SiLU(),
87
+ nn.Linear(base_channels * 4, base_channels * 4)
88
+ )
89
+
90
+ # Initial convolution
91
+ self.init_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)
92
+
93
+ # Downsampling - store channel dims for skip connections
94
+ self.down_channels = []
95
+ self.down_blocks = nn.ModuleList([])
96
+ channels = base_channels
97
+ for i, mult in enumerate(ch_mult):
98
+ out_ch = base_channels * mult
99
+ self.down_channels.append(out_ch)
100
+ self.down_blocks.append(nn.ModuleList([
101
+ ResidualBlock(channels, out_ch, base_channels * 4),
102
+ ResidualBlock(out_ch, out_ch, base_channels * 4),
103
+ ]))
104
+ channels = out_ch
105
+ if i < len(ch_mult) - 1:
106
+ self.down_blocks[-1].append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
107
+ else:
108
+ self.down_blocks[-1].append(nn.Identity())
109
+
110
+ # Bottleneck
111
+ self.bottleneck = nn.ModuleList([
112
+ ResidualBlock(channels, channels, base_channels * 4),
113
+ SelfAttention(channels),
114
+ ResidualBlock(channels, channels, base_channels * 4)
115
+ ])
116
+
117
+ # Upsampling
118
+ self.up_blocks = nn.ModuleList([])
119
+ for i, mult in reversed(list(enumerate(ch_mult))):
120
+ out_ch = base_channels * mult
121
+ # Skip connections: match corresponding down block
122
+ # up_block[i] connects to down_block[i] (same resolution)
123
+ skip_ch = self.down_channels[i]
124
+ in_ch = channels + skip_ch
125
+
126
+ self.up_blocks.append(nn.ModuleList([
127
+ ResidualBlock(in_ch, out_ch, base_channels * 4),
128
+ ResidualBlock(out_ch, out_ch, base_channels * 4),
129
+ ]))
130
+ channels = out_ch
131
+ if i > 0:
132
+ self.up_blocks[-1].append(nn.Upsample(scale_factor=2))
133
+ else:
134
+ self.up_blocks[-1].append(nn.Identity())
135
+
136
+ # Final convolution
137
+ self.final_conv = nn.Sequential(
138
+ nn.GroupNorm(32, base_channels),
139
+ nn.SiLU(),
140
+ nn.Conv2d(base_channels, out_channels, 3, padding=1)
141
+ )
142
+
143
+ def forward(self, x, t):
144
+ # Time embedding
145
+ t_emb = self.time_embed(t)
146
+ t_emb = self.time_mlp(t_emb)
147
+
148
+ # Initial conv
149
+ h = self.init_conv(x)
150
+
151
+ # Downsampling with skip connections
152
+ skips = []
153
+ for down_block in self.down_blocks:
154
+ res1, res2, downsample = down_block
155
+ h = res1(h, t_emb)
156
+ h = res2(h, t_emb)
157
+ skips.append(h)
158
+ h = downsample(h)
159
+
160
+ # Bottleneck
161
+ for layer in self.bottleneck:
162
+ if isinstance(layer, SelfAttention):
163
+ h = layer(h)
164
+ else:
165
+ h = layer(h, t_emb)
166
+
167
+ # Upsampling with skip connections
168
+ for i, up_block in enumerate(self.up_blocks):
169
+ res1, res2, upsample = up_block
170
+ # Concatenate skip connection (reverse order)
171
+ skip_idx = len(skips) - 1 - i
172
+ if skip_idx >= 0:
173
+ h = torch.cat([h, skips[skip_idx]], dim=1)
174
+ h = res1(h, t_emb)
175
+ h = res2(h, t_emb)
176
+ h = upsample(h)
177
+
178
+ return self.final_conv(h)
179
+
180
+ if __name__ == "__main__":
181
+ # Test the model with smaller batch size
182
+ print("Initializing UNet...")
183
+ model = UNet(img_size=64, base_channels=128)
184
+
185
+ total_params = sum(p.numel() for p in model.parameters())
186
+ print(f"Total parameters: {total_params:,}")
187
+
188
+ # Test with small batch
189
+ print("\nTesting forward pass...")
190
+ x = torch.randn(1, 3, 64, 64)
191
+ t = torch.randint(0, 1000, (1,))
192
+
193
+ with torch.no_grad():
194
+ output = model(x, t)
195
+
196
+ print(f"Input shape: {x.shape}")
197
+ print(f"Output shape: {output.shape}")
198
+ print("\nModel architecture verified successfully!")