AlexWortega commited on
Commit
54c8086
·
verified ·
1 Parent(s): 51a44ad

Upload unet.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. unet.py +229 -0
unet.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ U-Net architecture for conditional diffusion on spatiotemporal PDE data.
3
+ Supports non-square inputs, time conditioning, and skip connections.
4
+ """
5
+ import math
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding for diffusion timesteps.

    Maps a batch of scalar timesteps ``t`` of shape ``[B]`` to embeddings of
    shape ``[B, dim]`` using geometrically spaced sin/cos frequencies
    (Transformer/DDPM style).

    Args:
        dim: output embedding dimension.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Geometric frequency spacing from 1 down to 1/10000.
        # Guard half == 1 (dim == 2): the original divided by (half - 1),
        # which is a ZeroDivisionError in that case.
        scale = math.log(10000) / (half - 1) if half > 1 else 1.0
        freqs = torch.exp(torch.arange(half, device=t.device) * -scale)
        args = t[:, None].float() * freqs[None, :]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        if self.dim % 2 == 1:
            # Odd dim: sin/cos halves total 2 * (dim // 2) = dim - 1 columns,
            # so pad one zero column to make the output exactly [B, dim].
            emb = F.pad(emb, (0, 1))
        return emb
24
+
25
+
26
class ResBlock(nn.Module):
    """Residual block: two 3x3 convs with GroupNorm + SiLU pre-activation,
    with the diffusion-time embedding injected as a per-channel bias after
    the first convolution.

    A 1x1 convolution projects the residual path whenever the channel count
    changes; otherwise the input is added back unchanged.
    """

    def __init__(self, in_ch, out_ch, time_dim, dropout=0.1):
        super().__init__()
        # Construction order matters for reproducible seeded initialization.
        self.norm1 = nn.GroupNorm(min(32, in_ch), in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch))
        self.norm2 = nn.GroupNorm(min(32, out_ch), out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        out = self.conv1(F.silu(self.norm1(x)))
        # Broadcast the time embedding over the spatial dimensions.
        out = out + self.time_mlp(t_emb)[:, :, None, None]
        out = self.conv2(self.dropout(F.silu(self.norm2(out))))
        return out + self.skip(x)
47
+
48
+
49
+ class SelfAttention(nn.Module):
50
+ """Multi-head self-attention on spatial features."""
51
+
52
+ def __init__(self, channels, num_heads=4):
53
+ super().__init__()
54
+ self.norm = nn.GroupNorm(min(32, channels), channels)
55
+ self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
56
+
57
+ def forward(self, x):
58
+ B, C, H, W = x.shape
59
+ h = self.norm(x).reshape(B, C, H * W).permute(0, 2, 1)
60
+ h, _ = self.attn(h, h, h)
61
+ h = h.permute(0, 2, 1).reshape(B, C, H, W)
62
+ return x + h
63
+
64
+
65
class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, stride=2, padding=1)

    def forward(self, x):
        reduced = self.conv(x)
        return reduced
72
+
73
+
74
class Upsample(nn.Module):
    """Double the spatial resolution: nearest-neighbor upsample, then a 3x3
    convolution to smooth the result."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(enlarged)
82
+
83
+
84
class UNet(nn.Module):
    """U-Net for conditional diffusion.

    Condition (e.g. previous frame) is concatenated to the noisy input along
    the channel dimension *before* being passed to forward(). So set
    ``in_channels = output_channels + condition_channels``.

    Args:
        in_channels: noisy-target channels + condition channels.
        out_channels: channels to predict (same as target).
        base_ch: base channel width.
        ch_mults: per-level channel multipliers.
        n_res: residual blocks per level.
        attn_levels: which levels get self-attention (0-indexed).
        dropout: dropout rate.
        time_dim: timestep embedding dimension.

    NOTE(review): with len(ch_mults) levels there are len(ch_mults) - 1
    down/up steps, so H and W presumably must be divisible by
    2 ** (len(ch_mults) - 1) for the skip concatenations to line up —
    confirm against the data pipeline.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        base_ch=64,
        ch_mults=(1, 2, 4, 8),
        n_res=2,
        attn_levels=(3,),
        dropout=0.1,
        time_dim=256,
    ):
        super().__init__()
        self.n_res = n_res
        self.ch_mults = ch_mults

        # --- time embedding ---
        # Sinusoidal features expanded through a 2-layer MLP
        # (time_dim -> 4*time_dim -> time_dim).
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim),
        )

        # --- input projection ---
        self.input_conv = nn.Conv2d(in_channels, base_ch, 3, padding=1)

        # --- downsampling path ---
        # Each entry of self.downs is a ModuleDict holding either a "res"
        # stage (with optional "attn") or a "down" stage; forward() dispatches
        # on which key is present.
        self.downs = nn.ModuleList()
        ch = base_ch
        skip_chs = [ch]  # track channel dims for skip connections

        for lvl, mult in enumerate(ch_mults):
            out_ch = base_ch * mult
            for _ in range(n_res):
                self.downs.append(
                    nn.ModuleDict(
                        {
                            "res": ResBlock(ch, out_ch, time_dim, dropout),
                            **(
                                {"attn": SelfAttention(out_ch)}
                                if lvl in attn_levels
                                else {}
                            ),
                        }
                    )
                )
                ch = out_ch
                # One skip per res stage; forward() pushes h at the same points.
                skip_chs.append(ch)
            # No downsample after the deepest level.
            if lvl < len(ch_mults) - 1:
                self.downs.append(nn.ModuleDict({"down": Downsample(ch)}))
                skip_chs.append(ch)

        # --- middle ---
        self.mid_res1 = ResBlock(ch, ch, time_dim, dropout)
        self.mid_attn = SelfAttention(ch)
        self.mid_res2 = ResBlock(ch, ch, time_dim, dropout)

        # --- upsampling path ---
        # Mirrors the down path. skip_chs is consumed LIFO so each up-level
        # res block is built with exactly the channels it will concatenate.
        self.ups = nn.ModuleList()
        for lvl in reversed(range(len(ch_mults))):
            out_ch = base_ch * ch_mults[lvl]
            for _ in range(n_res + 1):  # +1 to consume downsample skip
                skip_ch = skip_chs.pop()
                self.ups.append(
                    nn.ModuleDict(
                        {
                            "res": ResBlock(ch + skip_ch, out_ch, time_dim, dropout),
                            **(
                                {"attn": SelfAttention(out_ch)}
                                if lvl in attn_levels
                                else {}
                            ),
                        }
                    )
                )
                ch = out_ch
            if lvl > 0:
                self.ups.append(nn.ModuleDict({"up": Upsample(ch)}))

        # --- output projection ---
        self.out_norm = nn.GroupNorm(min(32, ch), ch)
        self.out_conv = nn.Conv2d(ch, out_channels, 3, padding=1)

    def forward(self, x, t, cond=None):
        """
        Args:
            x: noisy target [B, C_out, H, W]
            t: diffusion timestep [B] (int or float)
            cond: condition [B, C_cond, H, W] (optional, concatenated)
        Returns:
            predicted noise [B, C_out, H, W]
        """
        if cond is not None:
            x = torch.cat([x, cond], dim=1)

        t_emb = self.time_embed(t)
        h = self.input_conv(x)

        # --- down ---
        # skips mirrors the skip_chs bookkeeping from __init__: one push after
        # the input conv, one after each res stage, one after each downsample.
        skips = [h]
        for block in self.downs:
            if "down" in block:
                h = block["down"](h)
                skips.append(h)
            else:
                h = block["res"](h, t_emb)
                if "attn" in block:
                    h = block["attn"](h)
                skips.append(h)

        # --- middle ---
        h = self.mid_res1(h, t_emb)
        h = self.mid_attn(h)
        h = self.mid_res2(h, t_emb)

        # --- up ---
        # Pop skips in reverse order and concatenate before each res block;
        # pop count matches push count by construction of self.ups.
        for block in self.ups:
            if "up" in block:
                h = block["up"](h)
            else:
                s = skips.pop()
                h = torch.cat([h, s], dim=1)
                h = block["res"](h, t_emb)
                if "attn" in block:
                    h = block["attn"](h)

        h = F.silu(self.out_norm(h))
        return self.out_conv(h)