| """ |
| PrivacyGraphTransformer: Graphormer-style transformer operating on AMR nodes. |
| No raw text positions used — uses structural encodings (centrality, spatial, edge). |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Tuple |
|
|
|
|
class CentralityEncoding(nn.Module):
    """Graphormer centrality encoding: a learnable vector per node degree.

    Degrees are clamped into ``[0, max_degree]`` before the lookup, so every
    degree above ``max_degree`` shares the last row of the table and negative
    degrees share row 0.
    """

    def __init__(self, hidden_dim: int, max_degree: int = 512):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.max_degree = max_degree
        # One row for each degree 0..max_degree inclusive.
        table = nn.Embedding(max_degree + 1, hidden_dim)
        nn.init.normal_(table.weight, std=0.02)
        self.degree_embed = table

    def forward(self, degree: torch.Tensor) -> torch.Tensor:
        """Return the centrality embedding for each entry of ``degree``.

        Args:
            degree: integer tensor of degrees, any shape (it indexes an
                ``nn.Embedding``, so it must be an integer dtype).

        Returns:
            Tensor of shape ``degree.shape + (hidden_dim,)``.
        """
        bounded = torch.clamp(degree, min=0, max=self.max_degree)
        return self.degree_embed(bounded)
|
|
|
|
class SpatialEncoding(nn.Module):
    """Graphormer spatial encoding: per-head attention bias from shortest-path distance.

    The table has ``max_distance + 2`` rows. Rows ``0..max_distance`` hold the
    bias for those exact distances. Row ``max_distance + 1`` is shared by every
    negative entry (the "unreachable" marker) AND by every distance greater
    than ``max_distance``.
    NOTE(review): confirm that merging "too far" with "unreachable" is intended.
    """

    def __init__(self, num_heads: int, max_distance: int = 128):
        super().__init__()
        self.num_heads = num_heads
        self.max_distance = max_distance
        self.spatial_bias = nn.Embedding(max_distance + 2, num_heads)
        # Zero init: the structural bias starts out contributing nothing.
        nn.init.zeros_(self.spatial_bias.weight)

    def forward(self, spd: torch.Tensor) -> torch.Tensor:
        """Map a 3-D integer distance tensor to a per-head bias.

        Args:
            spd: 3-D integer tensor of shortest-path distances, presumably
                (batch, N, N); negative values mean "unreachable".

        Returns:
            Tensor with the head dimension moved to position 1, i.e.
            ``(spd.shape[0], num_heads, spd.shape[1], spd.shape[2])``.
        """
        overflow_bucket = self.max_distance + 1
        index = spd.clamp(max=overflow_bucket).masked_fill(spd < 0, overflow_bucket)
        per_pair = self.spatial_bias(index)  # (..., num_heads)
        return per_pair.permute(0, 3, 1, 2)
|
|
|
|
class EdgeEncoding(nn.Module):
    """Placeholder for a Graphormer-style edge-feature attention bias.

    ``edge_embed`` is an ``nn.Embedding`` of shape
    ``(max_edge_features + 1, num_heads)``, zero-initialised. It is allocated
    (so it appears in ``state_dict``) but is not used yet: ``forward``
    currently always returns the scalar ``0.0``.
    """

    def __init__(self, num_heads: int, max_edge_features: int = 32):
        super().__init__()
        self.num_heads = num_heads
        self.edge_embed = nn.Embedding(max_edge_features + 1, num_heads)
        nn.init.zeros_(self.edge_embed.weight)

    def forward(self, edge_index, edge_types, spd_paths=None):
        """Return the edge attention bias; always the scalar ``0.0`` for now.

        All arguments are accepted for interface compatibility and ignored.
        """
        # TODO(review): implement the edge encoding (edge_embed is unused).
        del edge_index, edge_types, spd_paths
        return 0.0
|
|
|
|
class PrivacyAwareSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with structural bias injection.

    Optional spatial and edge biases are added to the attention logits before
    masking and softmax. The config object must provide ``hidden_dim``,
    ``num_heads``, ``attention_dropout``, ``use_spatial_encoding`` and
    ``use_edge_encoding``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_dim // config.num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.q_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.k_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.v_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.out_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.dropout = nn.Dropout(config.attention_dropout)
        self.use_spatial = config.use_spatial_encoding
        self.use_edge = config.use_edge_encoding

    def forward(self, x, attn_mask=None, spatial_bias=None, edge_bias=None,
                key_padding_mask=None):
        """Apply biased, masked multi-head self-attention.

        Args:
            x: tensor of shape (B, N, D) with D == hidden_dim.
            attn_mask: optional mask broadcastable to (B, num_heads, N, N);
                entries equal to 0 block attention.
            spatial_bias: optional additive logit bias broadcastable to
                (B, num_heads, N, N); used only when ``use_spatial``.
            edge_bias: optional additive logit bias (tensor or scalar); used
                only when ``use_edge``. The scalar ``0.0`` is treated as
                "no bias".
            key_padding_mask: optional (B, N) mask; keys where it equals 1
                (or True) are blocked for every query and head.

        Returns:
            Tensor of shape (B, N, D). Query rows whose keys are all blocked
            receive zero attention (their output is ``out_proj`` of zeros)
            instead of NaN.
        """
        B, N, D = x.shape
        # (B, N, D) -> (B, num_heads, N, head_dim)
        q = self.q_proj(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, H, N, N)

        if self.use_spatial and spatial_bias is not None:
            attn_scores = attn_scores + spatial_bias
        if self.use_edge and edge_bias is not None:
            # EdgeEncoding may return the scalar 0.0 placeholder; skip that.
            # Tensor biases are always added: `tensor != 0.0` in a boolean
            # context raises for multi-element tensors.
            if isinstance(edge_bias, torch.Tensor) or edge_bias != 0.0:
                attn_scores = attn_scores + edge_bias

        masked = attn_mask is not None or key_padding_mask is not None
        if attn_mask is not None:
            attn_scores = attn_scores.masked_fill(attn_mask == 0, float("-inf"))
        if key_padding_mask is not None:
            # (B, N) -> (B, 1, 1, N): block the same keys for every head and query.
            pad = key_padding_mask.unsqueeze(1).unsqueeze(2)
            attn_scores = attn_scores.masked_fill(pad == 1, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        if masked:
            # A row whose logits are all -inf softmaxes to NaN (0/0); give it
            # zero attention instead so NaNs do not spread through later layers.
            fully_blocked = torch.isneginf(attn_scores).all(dim=-1, keepdim=True)
            attn_weights = attn_weights.masked_fill(fully_blocked, 0.0)
        attn_weights = self.dropout(attn_weights)
        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.out_proj(out)
|
|
|
|
class PrivacyFeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> GELU -> Dropout -> Linear.

    Maps the last dimension ``hidden_dim -> ffn_dim -> hidden_dim``. Dropout is
    applied only to the intermediate activation, not to the output.
    """

    def __init__(self, hidden_dim: int, ffn_dim: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        hidden = self.dropout(self.activation(self.fc1(x)))
        return self.fc2(hidden)
|
|
|
|
class PrivacyAwareBlock(nn.Module):
    """Pre-LayerNorm encoder block: x + Attn(LN(x)), then x + FFN(LN(x)).

    Each residual branch output passes through ``self.dropout`` before being
    added back. The structural biases and masks are forwarded unchanged to
    ``PrivacyAwareSelfAttention``.
    """

    def __init__(self, config):
        super().__init__()
        self.attn = PrivacyAwareSelfAttention(config)
        self.ffn = PrivacyFeedForward(config.hidden_dim, config.ffn_dim, config.dropout)
        self.ln1 = nn.LayerNorm(config.hidden_dim)
        self.ln2 = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(config.dropout)
        # Differential-privacy settings copied from config; forward() below
        # does not read them.
        self.use_dp = config.use_dp_training
        self.clip_norm = config.dp_clip_norm

    def forward(self, x, attn_mask=None, spatial_bias=None, edge_bias=None,
                key_padding_mask=None):
        """Run the attention and feed-forward sub-layers with residuals."""
        attended = self.attn(
            self.ln1(x),
            attn_mask=attn_mask,
            spatial_bias=spatial_bias,
            edge_bias=edge_bias,
            key_padding_mask=key_padding_mask,
        )
        hidden = x + self.dropout(attended)
        return hidden + self.dropout(self.ffn(self.ln2(hidden)))
|
|
|
|
class PrivacyGraphTransformer(nn.Module):
    """Graphormer-style encoder over graph nodes, without sequence positions.

    Node feature vectors are projected to ``hidden_dim`` and optionally
    augmented with node-type and degree (centrality) embeddings. They are then
    passed through ``num_encoder_layers`` PrivacyAwareBlock layers that share
    the same spatial / edge attention biases, and finally a LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        # Input width is num_abstract_types + 16.
        # NOTE(review): the meaning of the extra 16 feature columns is defined
        # by the feature builder, not here; confirm with the caller.
        self.node_embed = nn.Linear(config.num_abstract_types + 16, config.hidden_dim)
        self.type_embed = nn.Embedding(config.num_abstract_types, config.hidden_dim)
        # _init_weights only handles Linear/LayerNorm, so without this the type
        # embedding would keep PyTorch's default N(0, 1) init. That is
        # inconsistent with the std=0.02 used by CentralityEncoding and the
        # Linear layers here.
        nn.init.normal_(self.type_embed.weight, std=0.02)
        if config.use_centrality_encoding:
            self.centrality_embed = CentralityEncoding(config.hidden_dim, config.max_degree)
        if config.use_spatial_encoding:
            self.spatial_embed = SpatialEncoding(config.num_heads, config.max_spatial_dist)
        if config.use_edge_encoding:
            self.edge_embed = EdgeEncoding(config.num_heads, config.max_edge_features)
        self.layers = nn.ModuleList([
            PrivacyAwareBlock(config) for _ in range(config.num_encoder_layers)
        ])
        self.final_ln = nn.LayerNorm(config.hidden_dim)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Init Linear weights ~ N(0, 0.02) with zero bias; LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, node_features, node_types=None, degree=None, spd=None,
                edge_index=None, edge_types=None, attention_mask=None,
                key_padding_mask=None):
        """Encode a batch of graphs into per-node representations.

        Args:
            node_features: 3-D float tensor whose last dimension is
                ``num_abstract_types + 16`` (presumably (batch, nodes, features)).
            node_types: optional integer tensor of type ids for ``type_embed``.
            degree: optional integer degree tensor; used only when
                ``config.use_centrality_encoding`` is set.
            spd: optional integer shortest-path-distance tensor; used only when
                ``config.use_spatial_encoding`` is set.
            edge_index, edge_types: passed to the edge encoder when
                ``config.use_edge_encoding`` is set and ``edge_index`` is given.
            attention_mask: optional mask; entries equal to 0 block attention.
            key_padding_mask: optional mask; keys equal to 1 are blocked.

        Returns:
            Tensor with last dimension ``hidden_dim``, after the final LayerNorm.

        Raises:
            ValueError: if ``node_features`` is not 3-D.
        """
        # Explicit rank check instead of `B, N, F = node_features.shape`, which
        # shadowed the `F` (torch.nn.functional) alias and left unused locals.
        if node_features.dim() != 3:
            raise ValueError(
                "node_features must be a 3-D tensor, got shape "
                f"{tuple(node_features.shape)}"
            )
        x = self.node_embed(node_features)
        if self.config.use_centrality_encoding and degree is not None:
            x = x + self.centrality_embed(degree)
        if node_types is not None:
            x = x + self.type_embed(node_types)

        # Structural biases are computed once and shared by every layer.
        spatial_bias = None
        if self.config.use_spatial_encoding and spd is not None:
            spatial_bias = self.spatial_embed(spd)
        edge_bias = None
        if self.config.use_edge_encoding and edge_index is not None:
            edge_bias = self.edge_embed(edge_index, edge_types)

        for layer in self.layers:
            x = layer(x, attn_mask=attention_mask, spatial_bias=spatial_bias,
                      edge_bias=edge_bias, key_padding_mask=key_padding_mask)
        return self.final_ln(x)
|
|