AbstractPhil commited on
Commit
e5d012d
Β·
verified Β·
1 Parent(s): a9031b9

Create modeling_flow_match.py

Browse files
Files changed (1) hide show
  1. modeling_flow_match.py +412 -0
modeling_flow_match.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FlowMatchRelay model β€” HuggingFace compatible.
3
+
4
+ Usage:
5
+ from transformers import AutoModel
6
+ model = AutoModel.from_pretrained(
7
+ "AbstractPhil/geolip-diffusion-proto",
8
+ trust_remote_code=True
9
+ )
10
+
11
+ # Generate samples
12
+ samples = model.sample(n_samples=8, class_label=3) # 8 cats
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import math
19
+ from transformers import PreTrainedModel
20
+ from .configuration_flow_match import FlowMatchRelayConfig
21
+
22
+
23
+ # ══════════════════════════════════════════════════════════════════
24
+ # CONSTELLATION RELAY
25
+ # ══════════════════════════════════════════════════════════════════
26
+
27
+ class ConstellationRelay(nn.Module):
28
+ """
29
+ Geometric regulator for feature maps.
30
+ Fixed anchors on S^(d-1), multi-phase stroboscope triangulation,
31
+ gated residual correction.
32
+ """
33
+ def __init__(self, channels, patch_dim=16, n_anchors=16, n_phases=3,
34
+ pw_hidden=32, gate_init=-3.0, mode='channel'):
35
+ super().__init__()
36
+ assert channels % patch_dim == 0
37
+ self.channels = channels
38
+ self.patch_dim = patch_dim
39
+ self.n_patches = channels // patch_dim
40
+ self.n_anchors = n_anchors
41
+ self.n_phases = n_phases
42
+ self.mode = mode
43
+
44
+ P, A, d = self.n_patches, n_anchors, patch_dim
45
+
46
+ home = torch.empty(P, A, d)
47
+ nn.init.xavier_normal_(home.view(P * A, d))
48
+ home = F.normalize(home.view(P, A, d), dim=-1)
49
+ self.register_buffer('home', home)
50
+ self.anchors = nn.Parameter(home.clone())
51
+
52
+ tri_dim = n_phases * A
53
+ self.pw_w1 = nn.Parameter(torch.empty(P, tri_dim, pw_hidden))
54
+ self.pw_b1 = nn.Parameter(torch.zeros(1, P, pw_hidden))
55
+ self.pw_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
56
+ self.pw_b2 = nn.Parameter(torch.zeros(1, P, d))
57
+ for p in range(P):
58
+ nn.init.xavier_normal_(self.pw_w1.data[p])
59
+ nn.init.xavier_normal_(self.pw_w2.data[p])
60
+ self.pw_norm = nn.LayerNorm(d)
61
+ self.gates = nn.Parameter(torch.full((P,), gate_init))
62
+ self.norm = nn.LayerNorm(channels)
63
+
64
+ def drift(self):
65
+ h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
66
+ return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))
67
+
68
+ def at_phase(self, t):
69
+ h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
70
+ omega = self.drift().unsqueeze(-1)
71
+ so = omega.sin().clamp(min=1e-7)
72
+ return torch.sin((1-t)*omega)/so * h + torch.sin(t*omega)/so * c
73
+
74
+ def _relay_core(self, x_flat):
75
+ N, C = x_flat.shape
76
+ P, A, d = self.n_patches, self.n_anchors, self.patch_dim
77
+ x_n = self.norm(x_flat)
78
+ patches = x_n.reshape(N, P, d)
79
+ patches_n = F.normalize(patches, dim=-1)
80
+ phases = torch.linspace(0, 1, self.n_phases, device=x_flat.device).tolist()
81
+ tris = []
82
+ for t in phases:
83
+ at = F.normalize(self.at_phase(t), dim=-1)
84
+ tris.append(1.0 - torch.einsum('npd,pad->npa', patches_n, at))
85
+ tri = torch.cat(tris, dim=-1)
86
+ h = F.gelu(torch.einsum('npt,pth->nph', tri, self.pw_w1) + self.pw_b1)
87
+ pw = self.pw_norm(torch.einsum('nph,phd->npd', h, self.pw_w2) + self.pw_b2)
88
+ g = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)
89
+ blended = g * pw + (1-g) * patches
90
+ return x_flat + blended.reshape(N, C)
91
+
92
+ def forward(self, x):
93
+ B, C, H, W = x.shape
94
+ if self.mode == 'channel':
95
+ pooled = x.mean(dim=(-2, -1))
96
+ relayed = self._relay_core(pooled)
97
+ scale = (relayed / (pooled + 1e-8)).unsqueeze(-1).unsqueeze(-1)
98
+ return x * scale.clamp(-3, 3)
99
+ else:
100
+ x_flat = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
101
+ out = self._relay_core(x_flat)
102
+ return out.reshape(B, H, W, C).permute(0, 3, 1, 2)
103
+
104
+
105
+ # ══════════════════════════════════════════════════════════════════
106
+ # BUILDING BLOCKS
107
+ # ══════════════════════════════════════════════════════════════════
108
+
109
+ class SinusoidalPosEmb(nn.Module):
110
+ def __init__(self, dim):
111
+ super().__init__()
112
+ self.dim = dim
113
+
114
+ def forward(self, t):
115
+ half = self.dim // 2
116
+ emb = math.log(10000) / (half - 1)
117
+ emb = torch.exp(torch.arange(half, device=t.device, dtype=t.dtype) * -emb)
118
+ emb = t.unsqueeze(-1) * emb.unsqueeze(0)
119
+ return torch.cat([emb.sin(), emb.cos()], dim=-1)
120
+
121
+
122
+ class AdaGroupNorm(nn.Module):
123
+ def __init__(self, channels, cond_dim, n_groups=8):
124
+ super().__init__()
125
+ self.gn = nn.GroupNorm(min(n_groups, channels), channels, affine=False)
126
+ self.proj = nn.Linear(cond_dim, channels * 2)
127
+ nn.init.zeros_(self.proj.weight)
128
+ nn.init.zeros_(self.proj.bias)
129
+
130
+ def forward(self, x, cond):
131
+ x = self.gn(x)
132
+ scale, shift = self.proj(cond).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)
133
+ return x * (1 + scale) + shift
134
+
135
+
136
+ class ConvBlock(nn.Module):
137
+ def __init__(self, channels, cond_dim, use_relay=False,
138
+ relay_patch_dim=16, relay_n_anchors=16, relay_n_phases=3,
139
+ relay_pw_hidden=32, relay_gate_init=-3.0, relay_mode='channel'):
140
+ super().__init__()
141
+ self.dw_conv = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
142
+ self.norm = AdaGroupNorm(channels, cond_dim)
143
+ self.pw1 = nn.Conv2d(channels, channels * 4, 1)
144
+ self.pw2 = nn.Conv2d(channels * 4, channels, 1)
145
+ self.act = nn.GELU()
146
+ self.relay = ConstellationRelay(
147
+ channels,
148
+ patch_dim=min(relay_patch_dim, channels),
149
+ n_anchors=min(relay_n_anchors, channels),
150
+ n_phases=relay_n_phases,
151
+ pw_hidden=relay_pw_hidden,
152
+ gate_init=relay_gate_init,
153
+ mode=relay_mode) if use_relay else None
154
+
155
+ def forward(self, x, cond):
156
+ residual = x
157
+ x = self.dw_conv(x)
158
+ x = self.norm(x, cond)
159
+ x = self.pw1(x)
160
+ x = self.act(x)
161
+ x = self.pw2(x)
162
+ x = residual + x
163
+ if self.relay is not None:
164
+ x = self.relay(x)
165
+ return x
166
+
167
+
168
+ class SelfAttnBlock(nn.Module):
169
+ def __init__(self, channels, n_heads=4):
170
+ super().__init__()
171
+ self.n_heads = n_heads
172
+ self.head_dim = channels // n_heads
173
+ self.norm = nn.GroupNorm(8, channels)
174
+ self.qkv = nn.Conv2d(channels, channels * 3, 1)
175
+ self.out = nn.Conv2d(channels, channels, 1)
176
+ nn.init.zeros_(self.out.weight)
177
+ nn.init.zeros_(self.out.bias)
178
+
179
+ def forward(self, x):
180
+ B, C, H, W = x.shape
181
+ residual = x
182
+ x = self.norm(x)
183
+ qkv = self.qkv(x).reshape(B, 3, self.n_heads, self.head_dim, H * W)
184
+ q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
185
+ attn = F.scaled_dot_product_attention(q, k, v)
186
+ out = attn.reshape(B, C, H, W)
187
+ return residual + self.out(out)
188
+
189
+
190
+ class Downsample(nn.Module):
191
+ def __init__(self, channels):
192
+ super().__init__()
193
+ self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
194
+
195
+ def forward(self, x):
196
+ return self.conv(x)
197
+
198
+
199
+ class Upsample(nn.Module):
200
+ def __init__(self, channels):
201
+ super().__init__()
202
+ self.conv = nn.Conv2d(channels, channels, 3, padding=1)
203
+
204
+ def forward(self, x):
205
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
206
+ return self.conv(x)
207
+
208
+
209
+ # ══════════════════════════════════════════════════════════════════
210
+ # FLOW MATCHING UNET
211
+ # ══════════════════════════════════════════════════════════════════
212
+
213
+ class FlowMatchUNet(nn.Module):
214
+ def __init__(self, config):
215
+ super().__init__()
216
+ in_channels = config.in_channels
217
+ base_channels = config.base_channels
218
+ channel_mults = config.channel_mults
219
+ n_classes = config.n_classes
220
+ cond_dim = config.cond_dim
221
+ use_relay = config.use_relay
222
+ self.channel_mults = channel_mults
223
+
224
+ # Relay kwargs
225
+ rk = dict(
226
+ relay_patch_dim=config.relay_patch_dim,
227
+ relay_n_anchors=config.relay_n_anchors,
228
+ relay_n_phases=config.relay_n_phases,
229
+ relay_pw_hidden=config.relay_pw_hidden,
230
+ relay_gate_init=config.relay_gate_init,
231
+ relay_mode=config.relay_mode,
232
+ )
233
+
234
+ self.time_emb = nn.Sequential(
235
+ SinusoidalPosEmb(cond_dim),
236
+ nn.Linear(cond_dim, cond_dim), nn.GELU(),
237
+ nn.Linear(cond_dim, cond_dim))
238
+ self.class_emb = nn.Embedding(n_classes, cond_dim)
239
+ self.in_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)
240
+
241
+ # Encoder
242
+ self.enc = nn.ModuleList()
243
+ self.enc_down = nn.ModuleList()
244
+ ch_in = base_channels
245
+ enc_channels = [base_channels]
246
+
247
+ for i, mult in enumerate(channel_mults):
248
+ ch_out = base_channels * mult
249
+ self.enc.append(nn.ModuleList([
250
+ ConvBlock(ch_in, cond_dim) if ch_in == ch_out
251
+ else nn.Sequential(nn.Conv2d(ch_in, ch_out, 1),
252
+ ConvBlock(ch_out, cond_dim)),
253
+ ConvBlock(ch_out, cond_dim),
254
+ ]))
255
+ ch_in = ch_out
256
+ enc_channels.append(ch_out)
257
+ if i < len(channel_mults) - 1:
258
+ self.enc_down.append(Downsample(ch_out))
259
+
260
+ # Middle
261
+ mid_ch = ch_in
262
+ self.mid_block1 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay, **rk)
263
+ self.mid_attn = SelfAttnBlock(mid_ch, n_heads=4)
264
+ self.mid_block2 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay, **rk)
265
+
266
+ # Decoder
267
+ self.dec_up = nn.ModuleList()
268
+ self.dec_skip_proj = nn.ModuleList()
269
+ self.dec = nn.ModuleList()
270
+
271
+ for i in range(len(channel_mults) - 1, -1, -1):
272
+ ch_out = base_channels * channel_mults[i]
273
+ skip_ch = enc_channels.pop()
274
+ self.dec_skip_proj.append(nn.Conv2d(ch_in + skip_ch, ch_out, 1))
275
+ self.dec.append(nn.ModuleList([
276
+ ConvBlock(ch_out, cond_dim),
277
+ ConvBlock(ch_out, cond_dim),
278
+ ]))
279
+ ch_in = ch_out
280
+ if i > 0:
281
+ self.dec_up.append(Upsample(ch_out))
282
+
283
+ self.out_norm = nn.GroupNorm(8, ch_in)
284
+ self.out_conv = nn.Conv2d(ch_in, in_channels, 3, padding=1)
285
+ nn.init.zeros_(self.out_conv.weight)
286
+ nn.init.zeros_(self.out_conv.bias)
287
+
288
+ def forward(self, x, t, class_labels):
289
+ cond = self.time_emb(t) + self.class_emb(class_labels)
290
+ h = self.in_conv(x)
291
+ skips = [h]
292
+
293
+ for i in range(len(self.channel_mults)):
294
+ for block in self.enc[i]:
295
+ if isinstance(block, ConvBlock):
296
+ h = block(h, cond)
297
+ elif isinstance(block, nn.Sequential):
298
+ h = block[0](h)
299
+ h = block[1](h, cond)
300
+ else:
301
+ h = block(h)
302
+ skips.append(h)
303
+ if i < len(self.enc_down):
304
+ h = self.enc_down[i](h)
305
+
306
+ h = self.mid_block1(h, cond)
307
+ h = self.mid_attn(h)
308
+ h = self.mid_block2(h, cond)
309
+
310
+ for i in range(len(self.channel_mults)):
311
+ skip = skips.pop()
312
+ if i > 0:
313
+ h = self.dec_up[i - 1](h)
314
+ h = torch.cat([h, skip], dim=1)
315
+ h = self.dec_skip_proj[i](h)
316
+ for block in self.dec[i]:
317
+ h = block(h, cond)
318
+
319
+ h = self.out_norm(h)
320
+ h = F.silu(h)
321
+ return self.out_conv(h)
322
+
323
+
324
+ # ══════════════════════════════════════════════════════════════════
325
+ # HUGGINGFACE PRETRAINED MODEL WRAPPER
326
+ # ══════════════════════════════════════════════════════════════════
327
+
328
+ class FlowMatchRelayModel(PreTrainedModel):
329
+ """
330
+ HuggingFace-compatible wrapper for flow matching with constellation relay.
331
+
332
+ Load:
333
+ model = AutoModel.from_pretrained(
334
+ "AbstractPhil/geolip-diffusion-proto", trust_remote_code=True)
335
+
336
+ Generate:
337
+ images = model.sample(n_samples=8, class_label=3)
338
+ """
339
+ config_class = FlowMatchRelayConfig
340
+
341
+ def __init__(self, config):
342
+ super().__init__(config)
343
+ self.unet = FlowMatchUNet(config)
344
+
345
+ def forward(self, x, t, class_labels):
346
+ """
347
+ Predict velocity field for flow matching.
348
+
349
+ Args:
350
+ x: (B, 3, H, W) noisy images
351
+ t: (B,) timesteps in [0, 1]
352
+ class_labels: (B,) integer class labels
353
+
354
+ Returns:
355
+ v_pred: (B, 3, H, W) predicted velocity
356
+ """
357
+ return self.unet(x, t, class_labels)
358
+
359
+ @torch.no_grad()
360
+ def sample(self, n_samples=8, n_steps=None, class_label=None, device=None):
361
+ """
362
+ Generate images via Euler ODE integration.
363
+
364
+ Args:
365
+ n_samples: number of images to generate
366
+ n_steps: ODE integration steps (default from config)
367
+ class_label: optional class conditioning (0-9 for CIFAR-10)
368
+ device: target device
369
+
370
+ Returns:
371
+ images: (n_samples, 3, 32, 32) in [0, 1]
372
+ """
373
+ if device is None:
374
+ device = next(self.parameters()).device
375
+ if n_steps is None:
376
+ n_steps = self.config.n_sample_steps
377
+
378
+ self.eval()
379
+ x = torch.randn(n_samples, self.config.in_channels,
380
+ self.config.image_size, self.config.image_size,
381
+ device=device)
382
+
383
+ if class_label is not None:
384
+ labels = torch.full((n_samples,), class_label,
385
+ dtype=torch.long, device=device)
386
+ else:
387
+ labels = torch.randint(0, self.config.n_classes,
388
+ (n_samples,), device=device)
389
+
390
+ dt = 1.0 / n_steps
391
+ for step in range(n_steps):
392
+ t_val = 1.0 - step * dt
393
+ t = torch.full((n_samples,), t_val, device=device)
394
+ v = self.unet(x, t, labels)
395
+ x = x - v * dt
396
+
397
+ # [-1, 1] β†’ [0, 1]
398
+ return (x.clamp(-1, 1) + 1) / 2
399
+
400
+ def get_relay_diagnostics(self):
401
+ """Report constellation relay drift and gate values."""
402
+ diagnostics = {}
403
+ for name, module in self.named_modules():
404
+ if isinstance(module, ConstellationRelay):
405
+ drift = module.drift().mean().item()
406
+ gate = module.gates.sigmoid().mean().item()
407
+ diagnostics[name] = {
408
+ 'drift_rad': drift,
409
+ 'drift_deg': math.degrees(drift),
410
+ 'gate': gate,
411
+ }
412
+ return diagnostics