"""
PrivacyGraphTransformer: Graphormer-style transformer operating on AMR nodes.

No raw text positions are used; the model relies only on structural
encodings (centrality, spatial, edge).
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


class CentralityEncoding(nn.Module):
    """Learnable degree-based centrality embedding (Graphormer).

    Degrees are clamped to ``[0, max_degree]`` and looked up in a table with
    ``max_degree + 1`` rows, so all degrees above ``max_degree`` share the
    last row.
    """

    def __init__(self, hidden_dim: int, max_degree: int = 512):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.max_degree = max_degree
        self.degree_embed = nn.Embedding(max_degree + 1, hidden_dim)
        nn.init.normal_(self.degree_embed.weight, std=0.02)

    def forward(self, degree: torch.Tensor) -> torch.Tensor:
        """Embed integer node degrees.

        Args:
            degree: Integer tensor of node degrees (any shape; the model
                passes a per-node tensor expected to be ``(B, N)``).

        Returns:
            Float tensor of shape ``degree.shape + (hidden_dim,)``.
        """
        degree = degree.clamp(0, self.max_degree)
        return self.degree_embed(degree)


class SpatialEncoding(nn.Module):
    """Learnable shortest-path-distance bias in attention (Graphormer).

    The table has ``max_distance + 2`` rows. Rows ``0..max_distance`` encode
    exact distances. The last row (index ``max_distance + 1``) is shared by
    distances greater than ``max_distance`` and by unreachable pairs, which
    are encoded as negative SPD values.
    """

    def __init__(self, num_heads: int, max_distance: int = 128):
        super().__init__()
        self.num_heads = num_heads
        self.max_distance = max_distance
        self.spatial_bias = nn.Embedding(max_distance + 2, num_heads)
        # Zero init: the structural bias starts as a no-op on attention scores.
        nn.init.zeros_(self.spatial_bias.weight)

    def forward(self, spd: torch.Tensor) -> torch.Tensor:
        """Map pairwise shortest-path distances to per-head attention biases.

        Args:
            spd: Integer tensor of shape ``(B, N, N)``; negative entries mark
                unreachable node pairs.

        Returns:
            Float tensor of shape ``(B, num_heads, N, N)``, ready to be added
            to attention scores.
        """
        overflow_bucket = self.max_distance + 1
        spd = spd.clamp(-1, overflow_bucket)
        # Unreachable pairs (negative SPD) share the overflow bucket.
        # masked_fill is used instead of torch.where(cond, scalar, tensor),
        # which older torch releases do not accept.
        spd = spd.masked_fill(spd < 0, overflow_bucket)
        bias = self.spatial_bias(spd)  # (B, N, N, num_heads)
        return bias.permute(0, 3, 1, 2)


class EdgeEncoding(nn.Module):
    """Edge feature attention bias (Graphormer).

    NOTE: not implemented yet. ``forward`` ignores its inputs and always
    returns the scalar ``0.0``, and ``edge_embed`` is never used.
    ``PrivacyAwareSelfAttention`` treats a scalar ``0.0`` as "no edge bias".
    """

    def __init__(self, num_heads: int, max_edge_features: int = 32):
        super().__init__()
        self.num_heads = num_heads
        self.edge_embed = nn.Embedding(max_edge_features + 1, num_heads)
        nn.init.zeros_(self.edge_embed.weight)

    def forward(self, edge_index, edge_types, spd_paths=None):
        """Return the edge attention bias; currently always ``0.0``.

        TODO: build a ``(B, num_heads, N, N)`` bias from ``edge_types`` along
        ``spd_paths`` using ``edge_embed``. The attention layer already
        accepts a tensor bias of that shape.
        """
        return 0.0


class PrivacyAwareSelfAttention(nn.Module):
    """Multi-head self-attention with structural bias injection.

    Optional spatial and edge biases are added to the scaled dot-product
    scores before masking and softmax.

    Raises:
        ValueError: if ``config.hidden_dim`` is not divisible by
            ``config.num_heads``.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_dim % config.num_heads != 0:
            # Fail fast; otherwise the head split in forward() fails with an
            # opaque view() shape error.
            raise ValueError(
                f"hidden_dim ({config.hidden_dim}) must be divisible by "
                f"num_heads ({config.num_heads})"
            )
        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_dim // config.num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.q_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.k_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.v_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.out_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.dropout = nn.Dropout(config.attention_dropout)

        self.use_spatial = config.use_spatial_encoding
        self.use_edge = config.use_edge_encoding

    @staticmethod
    def _has_edge_bias(edge_bias) -> bool:
        """Return True if ``edge_bias`` should be added to the scores.

        Tensors are always added; comparing a multi-element tensor to
        ``0.0`` inside a boolean context would raise. Python scalars are
        added only when non-zero (``EdgeEncoding`` returns ``0.0``).
        """
        if edge_bias is None:
            return False
        if isinstance(edge_bias, torch.Tensor):
            return True
        return edge_bias != 0.0

    def forward(self, x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                spatial_bias: Optional[torch.Tensor] = None,
                edge_bias=None,
                key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply biased multi-head self-attention.

        Args:
            x: ``(B, N, hidden_dim)`` node representations.
            attn_mask: Optional mask broadcastable to ``(B, num_heads, N, N)``;
                positions equal to 0 are excluded from attention.
            spatial_bias: Optional ``(B, num_heads, N, N)`` additive bias; used
                only when spatial encoding is enabled.
            edge_bias: Optional tensor broadcastable to the scores, or a scalar
                (``0.0`` means "none"); used only when edge encoding is enabled.
            key_padding_mask: Optional mask, expected ``(B, N)``; key positions
                equal to 1/True are excluded from attention.

        Returns:
            ``(B, N, hidden_dim)`` tensor. A query whose keys are all masked
            attends to nothing, so its output is just the ``out_proj`` bias.
        """
        B, N, D = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, N, D) -> (B, num_heads, N, head_dim)
            return t.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if self.use_spatial and spatial_bias is not None:
            attn_scores = attn_scores + spatial_bias
        if self.use_edge and self._has_edge_bias(edge_bias):
            attn_scores = attn_scores + edge_bias

        if attn_mask is not None:
            attn_scores = attn_scores.masked_fill(attn_mask == 0, float("-inf"))
        if key_padding_mask is not None:
            # Broadcast over heads and query positions: -> (B, 1, 1, N).
            pad = key_padding_mask.unsqueeze(1).unsqueeze(2)
            attn_scores = attn_scores.masked_fill(pad == 1, float("-inf"))

        # softmax over an all -inf row is NaN, which would poison the whole
        # residual stream. Neutralize such rows before softmax, then zero
        # their weights so they contribute nothing.
        fully_masked = torch.isneginf(attn_scores).all(dim=-1, keepdim=True)
        attn_scores = attn_scores.masked_fill(fully_masked, 0.0)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
        attn_weights = self.dropout(attn_weights)

        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.out_proj(out)


class PrivacyFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, hidden_dim: int, ffn_dim: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(..., hidden_dim)`` to ``(..., hidden_dim)``."""
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class PrivacyAwareBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes ``x = x + Dropout(Attn(LN1(x)))`` and then
    ``x = x + Dropout(FFN(LN2(x)))``.
    """

    def __init__(self, config):
        super().__init__()
        self.attn = PrivacyAwareSelfAttention(config)
        self.ffn = PrivacyFeedForward(config.hidden_dim, config.ffn_dim, config.dropout)
        self.ln1 = nn.LayerNorm(config.hidden_dim)
        self.ln2 = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(config.dropout)
        # Stored but not used by this block; presumably read by DP training
        # code elsewhere. TODO confirm.
        self.use_dp = config.use_dp_training
        self.clip_norm = config.dp_clip_norm

    def forward(self, x, attn_mask=None, spatial_bias=None, edge_bias=None,
                key_padding_mask=None):
        """Apply attention and feed-forward sublayers with residuals.

        Arguments are forwarded unchanged to ``PrivacyAwareSelfAttention``.
        """
        attn_out = self.attn(
            self.ln1(x),
            attn_mask=attn_mask,
            spatial_bias=spatial_bias,
            edge_bias=edge_bias,
            key_padding_mask=key_padding_mask,
        )
        x = x + self.dropout(attn_out)
        ffn_out = self.ffn(self.ln2(x))
        x = x + self.dropout(ffn_out)
        return x


class PrivacyGraphTransformer(nn.Module):
    """Graphormer-style encoder over graph nodes.

    Node features are linearly embedded. Optional type and centrality
    embeddings are added to them. The result then passes through
    ``num_encoder_layers`` pre-LN blocks that share one spatial/edge
    attention bias, followed by a final LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim

        # Input width is num_abstract_types + 16. The meaning of the extra
        # 16 feature columns is not visible here; presumably fixed by the
        # feature extractor. TODO confirm.
        self.node_embed = nn.Linear(config.num_abstract_types + 16, config.hidden_dim)
        self.type_embed = nn.Embedding(config.num_abstract_types, config.hidden_dim)

        if config.use_centrality_encoding:
            self.centrality_embed = CentralityEncoding(config.hidden_dim, config.max_degree)
        if config.use_spatial_encoding:
            self.spatial_embed = SpatialEncoding(config.num_heads, config.max_spatial_dist)
        if config.use_edge_encoding:
            self.edge_embed = EdgeEncoding(config.num_heads, config.max_edge_features)

        self.layers = nn.ModuleList([
            PrivacyAwareBlock(config) for _ in range(config.num_encoder_layers)
        ])
        self.final_ln = nn.LayerNorm(config.hidden_dim)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Init Linear (N(0, 0.02), zero bias) and LayerNorm (1, 0).

        Embeddings are left alone, which keeps the custom inits done inside
        the encoding modules.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, node_features, node_types=None, degree=None, spd=None,
                edge_index=None, edge_types=None, attention_mask=None,
                key_padding_mask=None):
        """Encode a batch of graphs.

        Args:
            node_features: ``(B, N, num_abstract_types + 16)`` float features.
            node_types: Optional integer type ids, expected ``(B, N)``.
            degree: Optional integer node degrees, expected ``(B, N)``.
            spd: Optional ``(B, N, N)`` shortest-path distances (negative =
                unreachable).
            edge_index: Optional edge structure; its presence triggers the
                (currently stub) edge encoding.
            edge_types: Optional edge type ids passed to the edge encoding.
            attention_mask: Optional mask broadcastable to
                ``(B, num_heads, N, N)``; zeros are masked out.
            key_padding_mask: Optional ``(B, N)`` mask; ones mark padding.

        Returns:
            ``(B, N, hidden_dim)`` node representations.
        """
        # Only the leading dims are relevant here. Do not bind the feature
        # width to `F`, which would shadow torch.nn.functional.
        B, N, _ = node_features.shape
        x = self.node_embed(node_features)

        if self.config.use_centrality_encoding and degree is not None:
            x = x + self.centrality_embed(degree)
        if node_types is not None:
            x = x + self.type_embed(node_types)

        # Structural biases are computed once and shared by all layers.
        spatial_bias = None
        if self.config.use_spatial_encoding and spd is not None:
            spatial_bias = self.spatial_embed(spd)

        edge_bias = None
        if self.config.use_edge_encoding and edge_index is not None:
            edge_bias = self.edge_embed(edge_index, edge_types)

        for layer in self.layers:
            x = layer(
                x,
                attn_mask=attention_mask,
                spatial_bias=spatial_bias,
                edge_bias=edge_bias,
                key_padding_mask=key_padding_mask,
            )

        x = self.final_ln(x)
        return x