"""
FlowMatchRelay model — HuggingFace compatible.

Usage:
    from transformers import AutoModel
    model = AutoModel.from_pretrained(
        "AbstractPhil/geolip-diffusion-proto",
        trust_remote_code=True
    )

    # Generate samples (returned with pixel values in [0, 1])
    samples = model.sample(n_samples=8, class_label=3)  # 8 cats (CIFAR-10 class 3)
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from transformers import PreTrainedModel |
| from .configuration_flow_match import FlowMatchRelayConfig |
|
|
|
|
| |
| |
| |
|
|
class ConstellationRelay(nn.Module):
    """
    Geometric regulator for feature maps.

    Fixed anchors on S^(d-1), multi-phase stroboscope triangulation, gated
    residual correction.

    The ``channels`` features are split into ``channels // patch_dim`` patches
    of size ``patch_dim``. Each patch owns ``n_anchors`` unit vectors: a fixed
    ``home`` buffer and trainable ``anchors`` initialised to the same points.
    For ``n_phases`` evenly spaced t in [0, 1] the anchors are slerped from
    home (t=0) to their current position (t=1). The cosine distances from
    each normalised patch to every phase's anchors feed a per-patch 2-layer
    MLP. A per-patch sigmoid gate mixes that MLP output with the layer-normed
    patch, and the mix is added to the input.

    Args:
        channels: feature width C; must be divisible by ``patch_dim``.
        patch_dim: size d of each patch.
        n_anchors: anchors per patch.
        n_phases: number of slerp phases sampled between home and anchors.
        pw_hidden: hidden width of the per-patch MLP.
        gate_init: initial gate logit (sigmoid(-3) ~= 0.047).
        mode: ``'channel'`` relays the spatially pooled (B, C) vector and
            rescales the map per channel; any other value relays every
            pixel's C-vector independently.

    Raises:
        ValueError: if ``channels`` is not divisible by ``patch_dim``.
    """

    def __init__(self, channels, patch_dim=16, n_anchors=16, n_phases=3,
                 pw_hidden=32, gate_init=-3.0, mode='channel'):
        super().__init__()
        if channels % patch_dim != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by "
                f"patch_dim ({patch_dim})")
        self.channels = channels
        self.patch_dim = patch_dim
        self.n_patches = channels // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        self.mode = mode

        P, A, d = self.n_patches, n_anchors, patch_dim

        # Home anchors: xavier-normal draws projected onto the unit sphere.
        # The trainable anchors start exactly at home (zero drift).
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Per-patch MLP: (n_phases * A) distances -> pw_hidden -> d.
        tri_dim = n_phases * A
        self.pw_w1 = nn.Parameter(torch.empty(P, tri_dim, pw_hidden))
        self.pw_b1 = nn.Parameter(torch.zeros(1, P, pw_hidden))
        self.pw_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.pw_b2 = nn.Parameter(torch.zeros(1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.pw_w1.data[p])
            nn.init.xavier_normal_(self.pw_w2.data[p])
        self.pw_norm = nn.LayerNorm(d)
        self.gates = nn.Parameter(torch.full((P,), gate_init))
        self.norm = nn.LayerNorm(channels)

    def drift(self):
        """Return the angle (radians) between each home point and its anchor, shape (P, A)."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        # Clamp keeps acos finite (and its gradient bounded) at exactly +/-1.
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def _slerp_terms(self):
        """Return the phase-independent slerp pieces (h, c, omega, sin(omega))."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        # Floor avoids dividing by ~0 when an anchor sits on its home point.
        so = omega.sin().clamp(min=1e-7)
        return h, c, omega, so

    @staticmethod
    def _slerp(t, h, c, omega, so):
        """Spherical interpolation from h (t=0) to c (t=1)."""
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def at_phase(self, t):
        """Anchor positions slerped a fraction ``t`` of the way from home to the current anchors."""
        return self._slerp(t, *self._slerp_terms())

    def _relay_core(self, x_flat):
        """Relay a batch of flat (N, C) feature vectors; returns (N, C)."""
        N, C = x_flat.shape
        P, d = self.n_patches, self.patch_dim
        x_n = self.norm(x_flat)
        patches = x_n.reshape(N, P, d)
        patches_n = F.normalize(patches, dim=-1)

        # Build the phase grid on the CPU: calling .tolist() on a device
        # tensor would force a host/device sync on every forward pass.
        phases = torch.linspace(0, 1, self.n_phases).tolist()
        # The slerp endpoints are the same for every phase; compute them once.
        h, c, omega, so = self._slerp_terms()
        tris = []
        for t in phases:
            at = F.normalize(self._slerp(t, h, c, omega, so), dim=-1)
            # Cosine distance from every patch to each of its anchors.
            tris.append(1.0 - torch.einsum('npd,pad->npa', patches_n, at))
        tri = torch.cat(tris, dim=-1)

        hidden = F.gelu(torch.einsum('npt,pth->nph', tri, self.pw_w1) + self.pw_b1)
        pw = self.pw_norm(torch.einsum('nph,phd->npd', hidden, self.pw_w2) + self.pw_b2)
        g = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)
        # NOTE(review): the (1 - g) * patches term adds LayerNorm(x) back to x,
        # so the relay is not an identity map even with closed gates. Kept
        # unchanged because trained weights presumably depend on it.
        blended = g * pw + (1 - g) * patches
        return x_flat + blended.reshape(N, C)

    def forward(self, x):
        """Apply the relay to a (B, C, H, W) feature map; returns the same shape.

        ``mode == 'channel'``: the spatial mean (B, C) is relayed once and the
        map is multiplied by the per-channel ratio ``relayed / pooled``,
        clamped to [-3, 3]. Any other mode relays each pixel's C-vector.
        """
        B, C, H, W = x.shape
        if self.mode == 'channel':
            pooled = x.mean(dim=(-2, -1))
            relayed = self._relay_core(pooled)
            # NOTE(review): +1e-8 only shifts the denominator; a mean near
            # -1e-8 still yields a huge ratio, which the clamp then bounds.
            scale = (relayed / (pooled + 1e-8)).unsqueeze(-1).unsqueeze(-1)
            return x * scale.clamp(-3, 3)
        x_flat = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
        out = self._relay_core(x_flat)
        return out.reshape(B, H, W, C).permute(0, 3, 1, 2)
|
|
|
|
| |
| |
| |
|
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a per-sample scalar (e.g. a timestep).

    With ``half = dim // 2`` frequencies ``f_k = 10000 ** (-k / (half - 1))``
    (k = 0 .. half-1), a ``t`` of shape (B,) maps to
    ``cat([sin(t * f), cos(t * f)])`` of shape (B, dim). When ``dim`` is odd,
    a zero column is appended so the width always equals ``dim``. When
    ``dim < 4`` the single frequency is 1 instead of dividing by zero.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # max(..., 1) guards the single-frequency case (dim 2 or 3), which
        # previously raised ZeroDivisionError.
        scale = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=t.device, dtype=t.dtype) * -scale)
        args = t.unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        if self.dim % 2:
            # Odd widths would otherwise yield dim-1 features and break the
            # nn.Linear(dim, ...) that consumes this embedding.
            emb = F.pad(emb, (0, 1))
        return emb
|
|
|
|
class AdaGroupNorm(nn.Module):
    """GroupNorm (no learned affine) modulated by a conditioning vector.

    ``cond`` (B, cond_dim) is projected to per-channel ``scale`` and
    ``shift``, and the output is ``gn(x) * (1 + scale) + shift``. The
    projection is zero-initialised, so the block starts as plain GroupNorm.

    Args:
        channels: number of input channels C.
        cond_dim: width of the conditioning vector.
        n_groups: preferred number of groups. If it does not divide
            ``channels``, the largest divisor of ``channels`` not exceeding it
            is used instead.
    """

    def __init__(self, channels, cond_dim, n_groups=8):
        super().__init__()
        self.gn = nn.GroupNorm(self._num_groups(channels, n_groups), channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    @staticmethod
    def _num_groups(channels, n_groups):
        """Return the largest g <= min(n_groups, channels) with channels % g == 0.

        This equals ``min(n_groups, channels)`` whenever that value is already
        a valid GroupNorm setting, so existing configurations are unchanged.
        Otherwise (e.g. 12 channels, 8 groups) it falls back to a divisor
        instead of letting GroupNorm raise.
        """
        for g in range(min(n_groups, channels), 0, -1):
            if channels % g == 0:
                return g
        return 1

    def forward(self, x, cond):
        x = self.gn(x)
        # (B, 2C) -> (B, 2C, 1, 1) -> two (B, C, 1, 1) halves.
        scale, shift = self.proj(cond).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)
        return x * (1 + scale) + shift
|
|
|
|
class ConvBlock(nn.Module):
    """Conditioned residual conv block, optionally followed by a ConstellationRelay.

    Residual branch: depthwise 7x7 conv -> AdaGroupNorm(cond) -> 1x1 conv to
    4x width -> GELU -> 1x1 conv back to ``channels``. The branch is added to
    the input. When ``use_relay`` is true, the sum then passes through a
    ConstellationRelay whose ``patch_dim`` and ``n_anchors`` are capped at
    ``channels``; the remaining ``relay_*`` arguments are forwarded as-is.
    """

    def __init__(self, channels, cond_dim, use_relay=False,
                 relay_patch_dim=16, relay_n_anchors=16, relay_n_phases=3,
                 relay_pw_hidden=32, relay_gate_init=-3.0, relay_mode='channel'):
        super().__init__()
        expanded = channels * 4
        self.dw_conv = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
        self.norm = AdaGroupNorm(channels, cond_dim)
        self.pw1 = nn.Conv2d(channels, expanded, 1)
        self.pw2 = nn.Conv2d(expanded, channels, 1)
        self.act = nn.GELU()
        if use_relay:
            self.relay = ConstellationRelay(
                channels,
                patch_dim=min(relay_patch_dim, channels),
                n_anchors=min(relay_n_anchors, channels),
                n_phases=relay_n_phases,
                pw_hidden=relay_pw_hidden,
                gate_init=relay_gate_init,
                mode=relay_mode,
            )
        else:
            self.relay = None

    def forward(self, x, cond):
        branch = self.norm(self.dw_conv(x), cond)
        branch = self.pw2(self.act(self.pw1(branch)))
        summed = x + branch
        if self.relay is None:
            return summed
        return self.relay(summed)
|
|
|
|
class SelfAttnBlock(nn.Module):
    """Residual multi-head attention block with a zero-initialised output conv.

    By default (``spatial=False``) this reproduces the original computation.
    ``scaled_dot_product_attention`` receives q/k/v shaped
    ``(B, n_heads, head_dim, H*W)``, so the attended sequence is the
    ``head_dim`` channels of each head and the feature axis is the H*W
    positions. This is attention across channels (scaled by 1/sqrt(H*W)),
    not across pixels. It stays the default so that weights presumably
    trained with it keep producing the same outputs.

    ``spatial=True`` transposes to ``(B, n_heads, H*W, head_dim)`` so every
    pixel attends to every other pixel, scaled by 1/sqrt(head_dim). Parameters
    and state-dict keys are identical in both modes.

    Raises:
        ValueError: if ``channels`` is not divisible by ``n_heads``.
    """

    def __init__(self, channels, n_heads=4, spatial=False):
        super().__init__()
        if channels % n_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.head_dim = channels // n_heads
        self.spatial = spatial
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.out = nn.Conv2d(channels, channels, 1)
        # Zero-init output projection: the block starts as an identity map.
        nn.init.zeros_(self.out.weight)
        nn.init.zeros_(self.out.bias)

    def forward(self, x):
        B, C, H, W = x.shape
        residual = x
        x = self.norm(x)
        qkv = self.qkv(x).reshape(B, 3, self.n_heads, self.head_dim, H * W)
        if self.spatial:
            # (.., head_dim, HW) -> (.., HW, head_dim): sequence = pixels.
            qkv = qkv.transpose(-1, -2)
        q, k, v = qkv.unbind(dim=1)
        attn = F.scaled_dot_product_attention(q, k, v)
        if self.spatial:
            attn = attn.transpose(-1, -2)
        out = attn.reshape(B, C, H, W)
        return residual + self.out(out)
|
|
|
|
class Downsample(nn.Module):
    """Halve the spatial resolution with a learned stride-2 3x3 convolution.

    With padding 1, an H x W input becomes ceil(H/2) x ceil(W/2); the channel
    count is unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        halved = self.conv(x)
        return halved
|
|
|
|
class Upsample(nn.Module):
    """Double the spatial resolution: nearest-neighbour 2x resize, then a 3x3 conv."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
|
|
|
|
| |
| |
| |
|
|
class FlowMatchUNet(nn.Module):
    """Class-conditional UNet that maps (x, t, class_labels) to a tensor shaped like x.

    Hyper-parameters are read from ``config``: ``in_channels``,
    ``base_channels``, ``channel_mults`` (one encoder/decoder level per entry,
    width ``base_channels * mult``), ``n_classes``, ``cond_dim``,
    ``use_relay`` and the ``relay_*`` settings.

    Layout:
      * conditioning = MLP(sinusoidal(t)) + Embedding(class_labels);
      * encoder: per level two ConvBlocks (a 1x1 conv first widens the input
        when the width changes), with a Downsample between levels;
      * bottleneck: ConvBlock -> SelfAttnBlock -> ConvBlock; only these two
        ConvBlocks receive ``use_relay`` and the ``relay_*`` settings;
      * decoder: per level, concatenate the same-level encoder output, fuse
        with a 1x1 conv, run two ConvBlocks; Upsample between levels;
      * head: GroupNorm -> SiLU -> zero-initialised 3x3 conv.

    Attribute names form the state-dict keys, so renaming them breaks loading
    of saved weights. Module creation order determines parameter
    initialisation under a fixed seed.
    """

    def __init__(self, config):
        super().__init__()
        base = config.base_channels
        mults = config.channel_mults
        cond_dim = config.cond_dim
        self.channel_mults = mults

        # Relay settings forwarded verbatim to the bottleneck ConvBlocks.
        relay_kwargs = {
            'relay_patch_dim': config.relay_patch_dim,
            'relay_n_anchors': config.relay_n_anchors,
            'relay_n_phases': config.relay_n_phases,
            'relay_pw_hidden': config.relay_pw_hidden,
            'relay_gate_init': config.relay_gate_init,
            'relay_mode': config.relay_mode,
        }

        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.class_emb = nn.Embedding(config.n_classes, cond_dim)
        self.in_conv = nn.Conv2d(config.in_channels, base, 3, padding=1)

        widths = [base * m for m in mults]

        # Encoder.
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        width_in = base
        for level, width in enumerate(widths):
            if width_in == width:
                entry = ConvBlock(width_in, cond_dim)
            else:
                entry = nn.Sequential(nn.Conv2d(width_in, width, 1),
                                      ConvBlock(width, cond_dim))
            self.enc.append(nn.ModuleList([entry, ConvBlock(width, cond_dim)]))
            if level < len(widths) - 1:
                self.enc_down.append(Downsample(width))
            width_in = width

        # Bottleneck at the deepest width.
        self.mid_block1 = ConvBlock(width_in, cond_dim,
                                    use_relay=config.use_relay, **relay_kwargs)
        self.mid_attn = SelfAttnBlock(width_in, n_heads=4)
        self.mid_block2 = ConvBlock(width_in, cond_dim,
                                    use_relay=config.use_relay, **relay_kwargs)

        # Decoder, deepest level first.
        self.dec_up = nn.ModuleList()
        self.dec_skip_proj = nn.ModuleList()
        self.dec = nn.ModuleList()
        for level in reversed(range(len(widths))):
            width = widths[level]
            # Input = previous stage output (width_in) ++ encoder skip (width).
            self.dec_skip_proj.append(nn.Conv2d(width_in + width, width, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(width, cond_dim),
                ConvBlock(width, cond_dim),
            ]))
            if level > 0:
                self.dec_up.append(Upsample(width))
            width_in = width

        self.out_norm = nn.GroupNorm(8, width_in)
        self.out_conv = nn.Conv2d(width_in, config.in_channels, 3, padding=1)
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        cond = self.time_emb(t) + self.class_emb(class_labels)
        h = self.in_conv(x)

        # One skip per encoder level; the decoder consumes them deepest-first.
        skips = []
        for level, blocks in enumerate(self.enc):
            for block in blocks:
                if isinstance(block, ConvBlock):
                    h = block(h, cond)
                elif isinstance(block, nn.Sequential):
                    widen, conv_block = block
                    h = conv_block(widen(h), cond)
                else:
                    h = block(h)
            skips.append(h)
            if level < len(self.enc_down):
                h = self.enc_down[level](h)

        h = self.mid_block1(h, cond)
        h = self.mid_attn(h)
        h = self.mid_block2(h, cond)

        for level, (fuse, blocks) in enumerate(zip(self.dec_skip_proj, self.dec)):
            if level > 0:
                h = self.dec_up[level - 1](h)
            h = fuse(torch.cat([h, skips.pop()], dim=1))
            for block in blocks:
                h = block(h, cond)

        return self.out_conv(F.silu(self.out_norm(h)))
|
|
|
|
| |
| |
| |
|
|
class FlowMatchRelayModel(PreTrainedModel):
    """
    HuggingFace-compatible wrapper for flow matching with constellation relay.

    Load:
        model = AutoModel.from_pretrained(
            "AbstractPhil/geolip-diffusion-proto", trust_remote_code=True)

    Generate:
        images = model.sample(n_samples=8, class_label=3)
    """
    config_class = FlowMatchRelayConfig
    _tied_weights_keys = []
    _keys_to_ignore_on_load_missing = []
    _keys_to_ignore_on_load_unexpected = []
    _no_split_modules = []
    supports_gradient_checkpointing = False

    def __init__(self, config):
        super().__init__(config)
        self.unet = FlowMatchUNet(config)
        self.post_init()

    def _init_weights(self, module):
        """No-op — weights are loaded from a checkpoint or already initialized
        by the submodules' own constructors."""
        pass

    def forward(self, x, t, class_labels):
        """
        Predict velocity field for flow matching.

        Args:
            x: (B, C, H, W) noisy images, C = config.in_channels
            t: (B,) timesteps in [0, 1]
            class_labels: (B,) integer class labels

        Returns:
            v_pred: tensor shaped like ``x`` (predicted velocity)
        """
        return self.unet(x, t, class_labels)

    @torch.no_grad()
    def sample(self, n_samples=8, n_steps=None, class_label=None, device=None,
               generator=None):
        """
        Generate images by Euler integration of the learned velocity field.

        Starts from Gaussian noise at t=1 and applies ``x <- x - v(x, t) * dt``
        for ``n_steps`` steps (the last evaluated t is ``dt``). The model is
        put in eval mode for the call, and its previous train/eval mode is
        restored afterwards.

        Args:
            n_samples: number of images to generate.
            n_steps: number of Euler steps (default ``config.n_sample_steps``);
                must be >= 1.
            class_label: optional int in ``[0, config.n_classes)`` used for
                every sample; when None each sample gets a uniformly random
                class.
            device: target device (default: device of the model parameters).
            generator: optional ``torch.Generator`` for reproducible noise and
                labels; it must live on ``device``.

        Returns:
            images: (n_samples, config.in_channels, config.image_size,
            config.image_size) in [0, 1], in the model's parameter dtype.

        Raises:
            ValueError: if ``n_steps < 1`` or ``class_label`` is out of range.
        """
        param = next(self.parameters())
        if device is None:
            device = param.device
        # Match the weights' dtype so fp16/bf16 models do not fail on
        # float32 inputs.
        dtype = param.dtype
        if n_steps is None:
            n_steps = self.config.n_sample_steps
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")
        if class_label is not None and not 0 <= class_label < self.config.n_classes:
            # Validate on the host: an out-of-range index would otherwise
            # surface as an opaque device-side assert on CUDA.
            raise ValueError(
                f"class_label must be in [0, {self.config.n_classes}), got {class_label}")

        was_training = self.training
        self.eval()
        try:
            size = self.config.image_size
            x = torch.randn(n_samples, self.config.in_channels, size, size,
                            device=device, dtype=dtype, generator=generator)

            if class_label is not None:
                labels = torch.full((n_samples,), class_label,
                                    dtype=torch.long, device=device)
            else:
                labels = torch.randint(0, self.config.n_classes, (n_samples,),
                                       device=device, generator=generator)

            dt = 1.0 / n_steps
            for step in range(n_steps):
                t_val = 1.0 - step * dt
                t = torch.full((n_samples,), t_val, device=device, dtype=dtype)
                v = self.unet(x, t, labels)
                x = x - v * dt
        finally:
            # Do not leave a model that was training stuck in eval mode.
            self.train(was_training)

        # Map from [-1, 1] to [0, 1].
        return (x.clamp(-1, 1) + 1) / 2

    def get_relay_diagnostics(self):
        """Report constellation relay drift and gate values.

        Returns:
            dict mapping each ConstellationRelay's module name to its mean
            drift (radians and degrees) and mean sigmoid gate value.
        """
        diagnostics = {}
        for name, module in self.named_modules():
            if isinstance(module, ConstellationRelay):
                drift = module.drift().mean().item()
                gate = module.gates.sigmoid().mean().item()
                diagnostics[name] = {
                    'drift_rad': drift,
                    'drift_deg': math.degrees(drift),
                    'gate': gate,
                }
        return diagnostics