"""
FlowMatchRelay model — HuggingFace compatible.

Usage::

    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "AbstractPhil/geolip-diffusion-proto", trust_remote_code=True
    )

    # Generate samples
    samples = model.sample(n_samples=8, class_label=3)  # 8 cats
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_flow_match import FlowMatchRelayConfig


# ══════════════════════════════════════════════════════════════════
# CONSTELLATION RELAY
# ══════════════════════════════════════════════════════════════════

class ConstellationRelay(nn.Module):
    """
    Geometric regulator for feature maps.

    Fixed anchors on S^(d-1), multi-phase stroboscope triangulation,
    gated residual correction.

    The ``channels`` features are split into ``channels // patch_dim``
    patches of size ``patch_dim``. Each patch owns ``n_anchors`` unit
    vectors, stored twice: a frozen copy (``home`` buffer) and a trainable
    copy (``anchors``). At ``n_phases`` evenly spaced points of a spherical
    interpolation from ``home`` to ``anchors``, the cosine distance from
    every patch to every anchor is measured. These distances feed a
    per-patch two-layer MLP whose output is blended with the normalized
    patch through a learned per-patch sigmoid gate.

    Args:
        channels: number of input channels; must be divisible by
            ``patch_dim``.
        patch_dim: size of each channel patch.
        n_anchors: anchors per patch.
        n_phases: number of interpolation phases sampled in [0, 1].
        pw_hidden: hidden width of the per-patch MLP.
        gate_init: initial gate logit (sigmoid(-3.0) ≈ 0.047).
        mode: ``'channel'`` relays the spatially pooled features and
            rescales the map with the result; any other value relays
            every pixel independently.

    Raises:
        ValueError: if ``patch_dim`` is not a positive divisor of
            ``channels``.
    """

    def __init__(self, channels, patch_dim=16, n_anchors=16, n_phases=3,
                 pw_hidden=32, gate_init=-3.0, mode='channel'):
        super().__init__()
        # Raise instead of assert so the check survives `python -O`.
        if patch_dim <= 0 or channels % patch_dim != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by a positive "
                f"patch_dim (got {patch_dim})")
        self.channels = channels
        self.patch_dim = patch_dim
        self.n_patches = channels // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        self.mode = mode
        P, A, d = self.n_patches, n_anchors, patch_dim

        # Unit-norm anchors: the frozen reference and its trainable twin
        # start out identical.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Per-patch MLP: (n_phases * A) distances -> pw_hidden -> d.
        tri_dim = n_phases * A
        self.pw_w1 = nn.Parameter(torch.empty(P, tri_dim, pw_hidden))
        self.pw_b1 = nn.Parameter(torch.zeros(1, P, pw_hidden))
        self.pw_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.pw_b2 = nn.Parameter(torch.zeros(1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.pw_w1.data[p])
            nn.init.xavier_normal_(self.pw_w2.data[p])
        self.pw_norm = nn.LayerNorm(d)
        # float() so an integer gate_init (e.g. from a JSON config) still
        # yields a floating-point tensor; a long tensor cannot be a Parameter.
        self.gates = nn.Parameter(torch.full((P,), float(gate_init)))
        self.norm = nn.LayerNorm(channels)

    def drift(self):
        """Return the angle in radians between each home and trainable anchor, shape (P, A)."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        # Clamp away from ±1, where the acos gradient is infinite.
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """
        Spherically interpolate from ``home`` (t=0) to ``anchors`` (t=1).

        Args:
            t: interpolation position, a Python float.

        Returns:
            (P, A, d) tensor of interpolated anchors (callers re-normalize).
        """
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        # Guard the division when the anchors have barely drifted.
        so = omega.sin().clamp(min=1e-7)
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def _relay_core(self, x_flat):
        """
        Relay a batch of flat feature vectors.

        Args:
            x_flat: (N, C) features with C == self.channels.

        Returns:
            (N, C) tensor: ``x_flat`` plus the gated blend of the MLP output
            and the normalized patches.
        """
        N, C = x_flat.shape
        P, d = self.n_patches, self.patch_dim
        x_n = self.norm(x_flat)
        patches = x_n.reshape(N, P, d)
        patches_n = F.normalize(patches, dim=-1)

        # The phase positions are only used as Python floats. Build them on
        # the CPU so .tolist() never forces a device synchronization.
        phases = torch.linspace(0, 1, self.n_phases).tolist()
        tris = []
        for t in phases:
            at = F.normalize(self.at_phase(t), dim=-1)
            # Cosine distance from every patch to every anchor: (N, P, A).
            tris.append(1.0 - torch.einsum('npd,pad->npa', patches_n, at))
        tri = torch.cat(tris, dim=-1)  # (N, P, n_phases * A)

        h = F.gelu(torch.einsum('npt,pth->nph', tri, self.pw_w1) + self.pw_b1)
        pw = self.pw_norm(torch.einsum('nph,phd->npd', h, self.pw_w2) + self.pw_b2)
        g = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)  # (1, P, 1)
        # NOTE(review): `blended` is added on top of x_flat, so with the gate
        # closed (g -> 0) the result is x + LayerNorm(x), not x. Kept as-is
        # because pretrained checkpoints were trained with this formulation.
        blended = g * pw + (1 - g) * patches
        return x_flat + blended.reshape(N, C)

    def forward(self, x):
        """
        Apply the relay to a (B, C, H, W) feature map.

        In ``'channel'`` mode, the spatial mean is relayed and the map is
        multiplied by the per-channel ratio relayed / pooled, clamped to
        [-3, 3]. In any other mode, every pixel's channel vector is relayed.
        """
        B, C, H, W = x.shape
        if self.mode == 'channel':
            pooled = x.mean(dim=(-2, -1))  # (B, C)
            relayed = self._relay_core(pooled)
            # The clamp bounds the ratio when `pooled` is close to zero.
            scale = (relayed / (pooled + 1e-8)).unsqueeze(-1).unsqueeze(-1)
            return x * scale.clamp(-3, 3)
        x_flat = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
        out = self._relay_core(x_flat)
        return out.reshape(B, H, W, C).permute(0, 3, 1, 2)


# ══════════════════════════════════════════════════════════════════
# BUILDING BLOCKS
# ══════════════════════════════════════════════════════════════════

class SinusoidalPosEmb(nn.Module):
    """
    Sinusoidal embedding of scalar timesteps.

    For ``t`` of shape (B,), returns (B, 2 * (dim // 2)): the sines followed
    by the cosines of ``t`` times geometrically spaced frequencies in
    [1/10000, 1]. Requires ``dim >= 4`` (``dim // 2 - 1`` is a divisor).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        emb = math.log(10000) / (half - 1)
        emb = torch.exp(torch.arange(half, device=t.device, dtype=t.dtype) * -emb)
        emb = t.unsqueeze(-1) * emb.unsqueeze(0)
        return torch.cat([emb.sin(), emb.cos()], dim=-1)


class AdaGroupNorm(nn.Module):
    """
    GroupNorm whose per-channel scale and shift are predicted from ``cond``.

    The projection is zero-initialized, so at init the output is the plain
    (affine-free) group-normalized input.
    """

    def __init__(self, channels, cond_dim, n_groups=8):
        super().__init__()
        self.gn = nn.GroupNorm(min(n_groups, channels), channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        # cond is expected as (B, cond_dim) so the (B, 2C, 1, 1) result
        # broadcasts over x's spatial dims.
        x = self.gn(x)
        scale, shift = self.proj(cond).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)
        return x * (1 + scale) + shift


class ConvBlock(nn.Module):
    """
    ConvNeXt-like residual block with conditional normalization.

    7x7 depthwise conv -> AdaGroupNorm(cond) -> 1x1 expand (4x) -> GELU ->
    1x1 project, added to the input. When ``use_relay`` is set, a
    ConstellationRelay is applied to the block output; its patch size and
    anchor count are capped at ``channels``.
    """

    def __init__(self, channels, cond_dim, use_relay=False,
                 relay_patch_dim=16, relay_n_anchors=16, relay_n_phases=3,
                 relay_pw_hidden=32, relay_gate_init=-3.0, relay_mode='channel'):
        super().__init__()
        self.dw_conv = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
        self.norm = AdaGroupNorm(channels, cond_dim)
        self.pw1 = nn.Conv2d(channels, channels * 4, 1)
        self.pw2 = nn.Conv2d(channels * 4, channels, 1)
        self.act = nn.GELU()
        self.relay = ConstellationRelay(
            channels,
            patch_dim=min(relay_patch_dim, channels),
            n_anchors=min(relay_n_anchors, channels),
            n_phases=relay_n_phases,
            pw_hidden=relay_pw_hidden,
            gate_init=relay_gate_init,
            mode=relay_mode,
        ) if use_relay else None

    def forward(self, x, cond):
        residual = x
        x = self.dw_conv(x)
        x = self.norm(x, cond)
        x = self.pw1(x)
        x = self.act(x)
        x = self.pw2(x)
        x = residual + x
        if self.relay is not None:
            x = self.relay(x)
        return x


class SelfAttnBlock(nn.Module):
    """
    Residual multi-head attention block over a (B, C, H, W) map.

    The output projection is zero-initialized, so the block starts as the
    identity.

    Raises:
        ValueError: if ``channels`` is not divisible by ``n_heads``.
    """

    def __init__(self, channels, n_heads=4):
        super().__init__()
        if channels % n_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.head_dim = channels // n_heads
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.out = nn.Conv2d(channels, channels, 1)
        nn.init.zeros_(self.out.weight)
        nn.init.zeros_(self.out.bias)

    def forward(self, x):
        B, C, H, W = x.shape
        residual = x
        x = self.norm(x)
        qkv = self.qkv(x).reshape(B, 3, self.n_heads, self.head_dim, H * W)
        # Each of q, k, v is (B, heads, head_dim, H*W).
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
        # NOTE(review): scaled_dot_product_attention treats the second-to-last
        # dim as the sequence and the last as features. Here that means
        # attention runs across the head_dim channels (with H*W as the feature
        # axis), not across spatial positions. Spatial attention would need
        # q/k/v transposed to (B, heads, H*W, head_dim). Left unchanged
        # because pretrained weights were trained with this layout.
        attn = F.scaled_dot_product_attention(q, k, v)
        out = attn.reshape(B, C, H, W)
        return residual + self.out(out)


class Downsample(nn.Module):
    """Halve the spatial size with a stride-2 3x3 conv (output is ceil(H/2))."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class Upsample(nn.Module):
    """Double the spatial size (nearest neighbour), then apply a 3x3 conv."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)


# ══════════════════════════════════════════════════════════════════
# FLOW MATCHING UNET
# ══════════════════════════════════════════════════════════════════

class FlowMatchUNet(nn.Module):
    """
    Class- and time-conditioned UNet that predicts a velocity field.

    Each resolution level has two ConvBlocks in the encoder and two in the
    decoder. Levels are joined by Downsample/Upsample, and each decoder
    level concatenates the matching encoder output and projects it with a
    1x1 conv. Constellation relays, if enabled, are only placed in the two
    middle ConvBlocks around the attention block.
    """

    def __init__(self, config):
        super().__init__()
        in_channels = config.in_channels
        base_channels = config.base_channels
        channel_mults = config.channel_mults
        n_classes = config.n_classes
        cond_dim = config.cond_dim
        use_relay = config.use_relay
        self.channel_mults = channel_mults

        # Relay kwargs forwarded to the middle ConvBlocks.
        rk = dict(
            relay_patch_dim=config.relay_patch_dim,
            relay_n_anchors=config.relay_n_anchors,
            relay_n_phases=config.relay_n_phases,
            relay_pw_hidden=config.relay_pw_hidden,
            relay_gate_init=config.relay_gate_init,
            relay_mode=config.relay_mode,
        )

        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.class_emb = nn.Embedding(n_classes, cond_dim)
        self.in_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Encoder: one ModuleList of two blocks per level. When the channel
        # count changes, the first block is Sequential(1x1 conv, ConvBlock).
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch_in = base_channels
        enc_channels = []  # channels of each level's skip, in encoder order
        for i, mult in enumerate(channel_mults):
            ch_out = base_channels * mult
            first = (ConvBlock(ch_in, cond_dim) if ch_in == ch_out
                     else nn.Sequential(nn.Conv2d(ch_in, ch_out, 1),
                                        ConvBlock(ch_out, cond_dim)))
            self.enc.append(nn.ModuleList([first, ConvBlock(ch_out, cond_dim)]))
            ch_in = ch_out
            enc_channels.append(ch_out)
            if i < len(channel_mults) - 1:
                self.enc_down.append(Downsample(ch_out))

        # Middle
        mid_ch = ch_in
        self.mid_block1 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay, **rk)
        self.mid_attn = SelfAttnBlock(mid_ch, n_heads=4)
        self.mid_block2 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay, **rk)

        # Decoder: built deepest level first, matching the forward order.
        self.dec_up = nn.ModuleList()
        self.dec_skip_proj = nn.ModuleList()
        self.dec = nn.ModuleList()
        for i in range(len(channel_mults) - 1, -1, -1):
            ch_out = base_channels * channel_mults[i]
            skip_ch = enc_channels.pop()
            self.dec_skip_proj.append(nn.Conv2d(ch_in + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch_in = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch_out))

        self.out_norm = nn.GroupNorm(8, ch_in)
        self.out_conv = nn.Conv2d(ch_in, in_channels, 3, padding=1)
        # Zero-init so the untrained network predicts zero velocity.
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        """
        Predict the velocity field.

        Args:
            x: (B, in_channels, H, W) input. H and W should be divisible by
                2 ** (len(channel_mults) - 1) so the skips line up.
            t: (B,) timesteps, or a scalar (Python number / 0-dim tensor)
                shared by the whole batch.
            class_labels: (B,) integer labels, or a scalar shared by the
                batch.

        Returns:
            (B, in_channels, H, W) velocity prediction.
        """
        batch = x.shape[0]
        # Accept scalars and put t on x's device/dtype, so the time MLP gets
        # the same dtype as the activations (assumed to match the weights).
        t = torch.as_tensor(t, device=x.device).to(x.dtype)
        if t.ndim == 0:
            t = t.expand(batch)
        class_labels = torch.as_tensor(class_labels, device=x.device)
        if class_labels.ndim == 0:
            class_labels = class_labels.expand(batch)
        cond = self.time_emb(t) + self.class_emb(class_labels)

        h = self.in_conv(x)
        skips = []
        for i in range(len(self.channel_mults)):
            for block in self.enc[i]:
                if isinstance(block, ConvBlock):
                    h = block(h, cond)
                elif isinstance(block, nn.Sequential):
                    # Channel-changing entry: 1x1 conv, then the conditioned block.
                    h = block[0](h)
                    h = block[1](h, cond)
                else:
                    h = block(h)
            skips.append(h)
            if i < len(self.enc_down):
                h = self.enc_down[i](h)

        h = self.mid_block1(h, cond)
        h = self.mid_attn(h)
        h = self.mid_block2(h, cond)

        for i in range(len(self.channel_mults)):
            skip = skips.pop()
            if i > 0:
                h = self.dec_up[i - 1](h)
            h = torch.cat([h, skip], dim=1)
            h = self.dec_skip_proj[i](h)
            for block in self.dec[i]:
                h = block(h, cond)

        h = self.out_norm(h)
        h = F.silu(h)
        return self.out_conv(h)


# ══════════════════════════════════════════════════════════════════
# HUGGINGFACE PRETRAINED MODEL WRAPPER
# ══════════════════════════════════════════════════════════════════

class FlowMatchRelayModel(PreTrainedModel):
    """
    HuggingFace-compatible wrapper for flow matching with constellation relay.

    Load::

        model = AutoModel.from_pretrained(
            "AbstractPhil/geolip-diffusion-proto", trust_remote_code=True)

    Generate::

        images = model.sample(n_samples=8, class_label=3)
    """

    config_class = FlowMatchRelayConfig
    _tied_weights_keys = []
    _keys_to_ignore_on_load_missing = []
    _keys_to_ignore_on_load_unexpected = []
    _no_split_modules = []
    supports_gradient_checkpointing = False

    def __init__(self, config):
        super().__init__(config)
        self.unet = FlowMatchUNet(config)
        self.post_init()

    def _init_weights(self, module):
        """No-op — weights loaded from checkpoint or already initialized."""
        pass

    def forward(self, x, t, class_labels):
        """
        Predict velocity field for flow matching.

        Args:
            x: (B, in_channels, H, W) noisy images.
            t: (B,) timesteps in [0, 1], or a scalar shared by the batch.
            class_labels: (B,) integer class labels, or a scalar.

        Returns:
            v_pred: (B, in_channels, H, W) predicted velocity.
        """
        return self.unet(x, t, class_labels)

    def _make_labels(self, n_samples, class_label, device):
        """Build and validate the (n_samples,) long label tensor used by ``sample``."""
        n_classes = self.config.n_classes
        if class_label is None:
            return torch.randint(0, n_classes, (n_samples,), device=device)
        labels = torch.as_tensor(class_label, dtype=torch.long, device=device)
        if labels.ndim == 0:
            labels = labels.repeat(n_samples)
        elif labels.shape != (n_samples,):
            raise ValueError(
                f"class_label must be a scalar or have shape ({n_samples},), "
                f"got {tuple(labels.shape)}")
        # Check the range here: an out-of-range index would otherwise show up
        # as an opaque embedding error (a device-side assert on CUDA).
        if labels.numel() and (labels.min() < 0 or labels.max() >= n_classes):
            raise ValueError(
                f"class_label values must be in [0, {n_classes}), got {class_label}")
        return labels

    @torch.no_grad()
    def sample(self, n_samples=8, n_steps=None, class_label=None, device=None):
        """
        Generate images via Euler ODE integration.

        Starts from Gaussian noise at t=1 and steps toward t=0 with
        x <- x - v * dt, where dt = 1 / n_steps. The model's train/eval mode
        is restored afterwards.

        Args:
            n_samples: number of images to generate.
            n_steps: ODE integration steps (default: config.n_sample_steps).
            class_label: optional class conditioning. Either one int applied
                to every sample, or a sequence/tensor of n_samples ints, each
                in [0, config.n_classes) (e.g. 0-9 for CIFAR-10). None draws
                random labels.
            device: target device (default: the model's parameter device).

        Returns:
            images: (n_samples, config.in_channels, config.image_size,
            config.image_size) tensor in [0, 1].

        Raises:
            ValueError: if n_steps < 1, or class_label has the wrong shape
                or an out-of-range value.
        """
        param = next(self.parameters())
        if device is None:
            device = param.device
        if n_steps is None:
            n_steps = self.config.n_sample_steps
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")
        # Match the parameter dtype so half/bfloat16 models work as well.
        dtype = param.dtype
        size = self.config.image_size

        was_training = self.training
        self.eval()
        try:
            # Noise first, labels second: same RNG order as before, so seeded
            # runs reproduce earlier samples.
            x = torch.randn(n_samples, self.config.in_channels, size, size,
                            device=device, dtype=dtype)
            labels = self._make_labels(n_samples, class_label, device)

            dt = 1.0 / n_steps
            for step in range(n_steps):
                t_val = 1.0 - step * dt
                t = torch.full((n_samples,), t_val, device=device, dtype=dtype)
                v = self.unet(x, t, labels)
                x = x - v * dt
        finally:
            # Restore the caller's mode instead of leaving the model in eval.
            self.train(was_training)

        # [-1, 1] → [0, 1]
        return (x.clamp(-1, 1) + 1) / 2

    @torch.no_grad()
    def get_relay_diagnostics(self):
        """
        Report constellation relay drift and gate values.

        Returns:
            dict mapping each ConstellationRelay's module name to a dict with
            'drift_rad' / 'drift_deg' (mean home-to-anchor angle) and 'gate'
            (mean sigmoid gate value).
        """
        diagnostics = {}
        for name, module in self.named_modules():
            if isinstance(module, ConstellationRelay):
                drift = module.drift().mean().item()
                gate = module.gates.sigmoid().mean().item()
                diagnostics[name] = {
                    'drift_rad': drift,
                    'drift_deg': math.degrees(drift),
                    'gate': gate,
                }
        return diagnostics