StruCTA / structa /encoder.py
YOUSSEF88's picture
Upload structa/encoder.py
9e564d6 verified
"""
PrivacyGraphTransformer: Graphormer-style transformer operating on AMR nodes.
No raw text positions used — uses structural encodings (centrality, spatial, edge).
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
class CentralityEncoding(nn.Module):
    """Graphormer-style centrality encoding: one learnable vector per node degree.

    Degrees are clamped into ``[0, max_degree]`` before the lookup, so every
    degree above ``max_degree`` shares the last embedding row and negative
    values share row 0.
    """

    def __init__(self, hidden_dim: int, max_degree: int = 512):
        super().__init__()
        self.max_degree = max_degree
        self.hidden_dim = hidden_dim
        table = nn.Embedding(max_degree + 1, hidden_dim)
        # Small-std init so the degree signal starts as a mild perturbation.
        nn.init.normal_(table.weight, std=0.02)
        self.degree_embed = table

    def forward(self, degree: torch.Tensor) -> torch.Tensor:
        """Return the embedding of each (clamped) degree.

        The result has shape ``degree.shape + (hidden_dim,)``; ``degree``
        must be an integer tensor usable as embedding indices.
        """
        bounded = torch.clamp(degree, min=0, max=self.max_degree)
        return self.degree_embed(bounded)
class SpatialEncoding(nn.Module):
    """Graphormer-style spatial encoding: a learnable per-head attention bias
    looked up by shortest-path distance (SPD).

    The table has ``max_distance + 2`` rows. Row ``max_distance + 1`` is a
    shared "far" bucket that receives both negative distances and distances
    larger than ``max_distance``. The table is zero-initialised, so the bias
    is a no-op until it is trained.
    """

    def __init__(self, num_heads: int, max_distance: int = 128):
        super().__init__()
        self.num_heads = num_heads
        self.max_distance = max_distance
        self.spatial_bias = nn.Embedding(max_distance + 2, num_heads)
        nn.init.zeros_(self.spatial_bias.weight)

    def forward(self, spd: torch.Tensor) -> torch.Tensor:
        """Map a pairwise SPD tensor to an additive attention bias.

        Args:
            spd: integer distances. The 4-D permute below implies a 3-D input,
                presumably (B, N, N). Negative entries (presumably "no path";
                TODO confirm with the data pipeline) go to the far bucket.

        Returns:
            Tensor of shape (B, num_heads, N, N).
        """
        far_bucket = self.max_distance + 1
        idx = torch.clamp(spd, min=-1, max=far_bucket)
        idx = idx.masked_fill(idx < 0, far_bucket)
        per_pair = self.spatial_bias(idx)  # (B, N, N, num_heads)
        # Move heads next to the batch dim to line up with attention scores.
        return per_pair.permute(0, 3, 1, 2)
class EdgeEncoding(nn.Module):
    """Edge-feature attention bias (Graphormer) -- currently a placeholder.

    A zero-initialised ``(max_edge_features + 1, num_heads)`` embedding table
    is registered, but ``forward`` does not use it yet and always returns the
    scalar ``0.0``.
    """

    def __init__(self, num_heads: int, max_edge_features: int = 32):
        super().__init__()
        self.num_heads = num_heads
        table = nn.Embedding(max_edge_features + 1, num_heads)
        nn.init.zeros_(table.weight)
        self.edge_embed = table

    def forward(self, edge_index, edge_types, spd_paths=None):
        """Return the edge attention bias.

        Not implemented yet: every argument is ignored and the scalar 0.0 is
        returned whether or not ``spd_paths`` is given.
        """
        return 0.0
class PrivacyAwareSelfAttention(nn.Module):
    """Multi-head self-attention with additive structural bias injection.

    Computes scaled dot-product attention over all node pairs. Before the
    softmax, optional spatial and edge biases are added to the scores, and
    optional attention / key-padding masks set excluded positions to -inf.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_dim % config.num_heads != 0:
            # forward() reshapes hidden_dim into num_heads * head_dim; a
            # remainder would only surface later as an opaque view() error.
            raise ValueError(
                f"hidden_dim ({config.hidden_dim}) must be divisible by "
                f"num_heads ({config.num_heads})"
            )
        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_dim // config.num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        self.q_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.k_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.v_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.out_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.dropout = nn.Dropout(config.attention_dropout)
        self.use_spatial = config.use_spatial_encoding
        self.use_edge = config.use_edge_encoding

    def forward(self, x, attn_mask=None, spatial_bias=None, edge_bias=None,
                key_padding_mask=None):
        """Apply biased multi-head self-attention.

        Args:
            x: node states, unpacked as (B, N, D); D must equal hidden_dim
                for the head reshape to succeed.
            attn_mask: optional; positions where it equals 0 are excluded.
                Must broadcast to (B, num_heads, N, N) -- TODO confirm the
                intended shape with callers.
            spatial_bias: optional additive bias broadcastable to
                (B, num_heads, N, N); used only when use_spatial_encoding.
            edge_bias: optional additive bias (tensor or scalar). A plain
                scalar 0.0 is treated as "no bias". Used only when
                use_edge_encoding.
            key_padding_mask: optional, presumably (B, N); keys where it
                equals 1/True are excluded.

        Returns:
            Tensor of shape (B, N, D) after the output projection. Query rows
            whose keys are all masked get zero attention, so their output is
            just out_proj's bias.
        """
        B, N, D = x.shape
        # (B, N, D) -> (B, num_heads, N, head_dim)
        q = self.q_proj(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, H, N, N)
        if self.use_spatial and spatial_bias is not None:
            attn_scores = attn_scores + spatial_bias
        if self.use_edge and edge_bias is not None:
            # The edge encoder may hand back a scalar 0.0 placeholder; skip
            # that, but never evaluate ``tensor != 0.0`` as a bool (ambiguous
            # for multi-element tensors).
            if torch.is_tensor(edge_bias) or edge_bias != 0.0:
                attn_scores = attn_scores + edge_bias
        if attn_mask is not None:
            attn_scores = attn_scores.masked_fill(attn_mask == 0, float("-inf"))
        if key_padding_mask is not None:
            # (B, N) -> (B, 1, 1, N): broadcast over heads and query rows.
            mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
            attn_scores = attn_scores.masked_fill(mask == 1, float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)
        # Softmax over an all -inf row yields NaN. Zero such rows so a fully
        # masked query attends to nothing instead of poisoning the output.
        fully_masked = torch.isneginf(attn_scores).all(dim=-1, keepdim=True)
        attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
        attn_weights = self.dropout(attn_weights)
        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.out_proj(out)
class PrivacyFeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> GELU -> Dropout -> Linear.

    No dropout is applied after the second projection. Callers such as
    PrivacyAwareBlock apply their own dropout to the result.
    """

    def __init__(self, hidden_dim: int, ffn_dim: int, dropout: float = 0.1):
        super().__init__()
        # Expand to ffn_dim, then project back to hidden_dim.
        self.fc1 = nn.Linear(hidden_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        """Apply the two-layer MLP independently at every position of ``x``."""
        hidden = self.dropout(self.activation(self.fc1(x)))
        return self.fc2(hidden)
class PrivacyAwareBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Runs self-attention and then a feed-forward sub-layer. Each sub-layer is
    applied as ``x + dropout(sublayer(layer_norm(x)))``.
    """

    def __init__(self, config):
        super().__init__()
        self.attn = PrivacyAwareSelfAttention(config)
        self.ffn = PrivacyFeedForward(config.hidden_dim, config.ffn_dim, config.dropout)
        self.ln1 = nn.LayerNorm(config.hidden_dim)
        self.ln2 = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(config.dropout)
        # Differential-privacy settings: stored on the module but not read
        # anywhere in this block.
        self.use_dp = config.use_dp_training
        self.clip_norm = config.dp_clip_norm

    def forward(self, x, attn_mask=None, spatial_bias=None, edge_bias=None,
                key_padding_mask=None):
        """Apply the attention and feed-forward residual sub-layers to ``x``.

        The mask and bias arguments are forwarded unchanged to the attention
        sub-layer.
        """
        normed = self.ln1(x)
        attended = self.attn(
            normed,
            attn_mask=attn_mask,
            spatial_bias=spatial_bias,
            edge_bias=edge_bias,
            key_padding_mask=key_padding_mask,
        )
        x = x + self.dropout(attended)
        return x + self.dropout(self.ffn(self.ln2(x)))
class PrivacyGraphTransformer(nn.Module):
    """Graphormer-style encoder over AMR graph nodes.

    Node feature vectors are linearly projected to hidden_dim. Node-type and
    degree (centrality) embeddings can be added on top. The result runs
    through a stack of PrivacyAwareBlock layers, whose attention can be biased
    by spatial (shortest-path) and edge encodings. A final LayerNorm is then
    applied.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        # Width of each node's input feature vector. Defaults to
        # num_abstract_types + 16; set ``config.node_feature_dim`` to override.
        node_feature_dim = getattr(config, "node_feature_dim", None)
        if node_feature_dim is None:
            node_feature_dim = config.num_abstract_types + 16
        self.node_embed = nn.Linear(node_feature_dim, config.hidden_dim)
        self.type_embed = nn.Embedding(config.num_abstract_types, config.hidden_dim)
        if config.use_centrality_encoding:
            self.centrality_embed = CentralityEncoding(config.hidden_dim, config.max_degree)
        if config.use_spatial_encoding:
            self.spatial_embed = SpatialEncoding(config.num_heads, config.max_spatial_dist)
        if config.use_edge_encoding:
            self.edge_embed = EdgeEncoding(config.num_heads, config.max_edge_features)
        self.layers = nn.ModuleList([
            PrivacyAwareBlock(config) for _ in range(config.num_encoder_layers)
        ])
        self.final_ln = nn.LayerNorm(config.hidden_dim)
        self.apply(self._init_weights)
        # _init_weights only touches Linear/LayerNorm, so embedding tables keep
        # their own init: the spatial/edge tables start at zero and the
        # centrality table uses std 0.02. Match the latter here instead of
        # PyTorch's default N(0, 1), which would dwarf the other input terms.
        nn.init.normal_(self.type_embed.weight, std=0.02)

    def _init_weights(self, module):
        """Initialise Linear (N(0, 0.02) weight, zero bias) and LayerNorm
        (unit weight, zero bias) modules; other module types are untouched."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, node_features, node_types=None, degree=None, spd=None,
                edge_index=None, edge_types=None, attention_mask=None,
                key_padding_mask=None):
        """Encode a batch of graphs.

        Args:
            node_features: 3-D float tensor, presumably (B, N, feat); its last
                dim must equal ``node_embed.in_features``.
            node_types: optional integer ids in [0, num_abstract_types), added
                through ``type_embed``.
            degree: optional integer node degrees; used only when
                use_centrality_encoding is enabled.
            spd: optional shortest-path distances; used only when
                use_spatial_encoding is enabled.
            edge_index, edge_types: passed to the edge encoder when
                use_edge_encoding is enabled and ``edge_index`` is given.
            attention_mask: passed to every layer as ``attn_mask``.
            key_padding_mask: passed to every layer unchanged.

        Returns:
            Node representations after the final LayerNorm, with last dim
            hidden_dim.

        Raises:
            ValueError: if ``node_features`` is not 3-dimensional.
        """
        if node_features.dim() != 3:
            raise ValueError(
                "node_features must be 3-D (batch, nodes, features), got shape "
                f"{tuple(node_features.shape)}"
            )
        x = self.node_embed(node_features)
        if self.config.use_centrality_encoding and degree is not None:
            x = x + self.centrality_embed(degree)
        if node_types is not None:
            x = x + self.type_embed(node_types)
        spatial_bias = None
        if self.config.use_spatial_encoding and spd is not None:
            spatial_bias = self.spatial_embed(spd)
        edge_bias = None
        if self.config.use_edge_encoding and edge_index is not None:
            edge_bias = self.edge_embed(edge_index, edge_types)
        for layer in self.layers:
            x = layer(x, attn_mask=attention_mask, spatial_bias=spatial_bias,
                      edge_bias=edge_bias, key_padding_mask=key_padding_mask)
        return self.final_ln(x)