"""
Graph Attention Network (GATv2) Encoder for LILITH.

Learns spatial relationships between weather stations using
attention-based message passing on a geographic graph.
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class GATv2Layer(nn.Module):
    """
    Graph Attention Network v2 layer.

    Implements the improved attention mechanism from
    "How Attentive are Graph Attention Networks?" (Brody et al., 2021).

    Key improvement: the attention vector is applied after the LeakyReLU
    nonlinearity rather than before it, which makes the attention dynamic
    and lets the scoring function act as a universal approximator.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_heads: int = 8,
        dropout: float = 0.1,
        edge_dim: Optional[int] = None,
        residual: bool = True,
        share_weights: bool = False,
    ):
        """
        Initialize GATv2 layer.

        Args:
            in_dim: Input feature dimension
            out_dim: Total output feature dimension (split across heads)
            num_heads: Number of attention heads
            dropout: Dropout probability
            edge_dim: Edge feature dimension (optional)
            residual: Whether to use a residual connection
            share_weights: Share weights between source and target transformations
        """
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.head_dim = out_dim // num_heads
        self.residual = residual

        assert out_dim % num_heads == 0, "out_dim must be divisible by num_heads"

        # Linear maps for nodes in their source and destination roles.
        self.W_src = nn.Linear(in_dim, out_dim, bias=False)
        if share_weights:
            self.W_dst = self.W_src
        else:
            self.W_dst = nn.Linear(in_dim, out_dim, bias=False)

        # Per-head attention vector `a`, applied after the LeakyReLU.
        self.attn = nn.Parameter(torch.empty(num_heads, self.head_dim))

        # Optional projection of edge features into the attention input.
        if edge_dim is not None:
            self.edge_proj = nn.Linear(edge_dim, out_dim, bias=False)
        else:
            self.edge_proj = None

        self.out_proj = nn.Linear(out_dim, out_dim)

        self.norm = nn.LayerNorm(out_dim)
        self.dropout = nn.Dropout(dropout)
        self.attn_dropout = nn.Dropout(dropout)

        # Project the residual branch when input and output widths differ.
        if residual and in_dim != out_dim:
            self.residual_proj = nn.Linear(in_dim, out_dim)
        else:
            self.residual_proj = None

        self._init_weights()

    def _init_weights(self):
        """Initialize weights."""
        nn.init.xavier_uniform_(self.W_src.weight)
        if self.W_dst is not self.W_src:
            nn.init.xavier_uniform_(self.W_dst.weight)
        nn.init.xavier_uniform_(self.attn)
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Node features of shape (num_nodes, in_dim)
            edge_index: Graph connectivity of shape (2, num_edges)
            edge_attr: Edge features of shape (num_edges, edge_dim)

        Returns:
            Updated node features of shape (num_nodes, out_dim)
        """
        num_nodes = x.size(0)
        src_idx, dst_idx = edge_index[0], edge_index[1]

        # Transform node features for the source and destination roles.
        h_src = self.W_src(x)
        h_dst = self.W_dst(x)

        # Split into attention heads: (num_nodes, num_heads, head_dim).
        h_src = h_src.view(num_nodes, self.num_heads, self.head_dim)
        h_dst = h_dst.view(num_nodes, self.num_heads, self.head_dim)

        # Gather the transformed features for each edge's endpoints.
        h_src_edge = h_src[src_idx]
        h_dst_edge = h_dst[dst_idx]

        # GATv2 attention input: sum of the transformed endpoint features.
        attn_input = h_src_edge + h_dst_edge

        # Optionally fold projected edge features into the attention input.
        if edge_attr is not None and self.edge_proj is not None:
            edge_h = self.edge_proj(edge_attr)
            edge_h = edge_h.view(-1, self.num_heads, self.head_dim)
            attn_input = attn_input + edge_h

        # GATv2 ordering: nonlinearity first, then the attention vector.
        attn_input = F.leaky_relu(attn_input, negative_slope=0.2)
        attn_scores = (attn_input * self.attn).sum(dim=-1)

        # Normalize scores over each destination node's incoming edges.
        attn_scores = self._sparse_softmax(attn_scores, dst_idx, num_nodes)
        attn_scores = self.attn_dropout(attn_scores)

        # Weight source features by attention and aggregate at destinations.
        messages = h_src_edge * attn_scores.unsqueeze(-1)

        out = torch.zeros(
            num_nodes, self.num_heads, self.head_dim,
            device=x.device, dtype=x.dtype,
        )
        out.scatter_add_(0, dst_idx.view(-1, 1, 1).expand_as(messages), messages)

        # Merge heads, project, and regularize.
        out = out.view(num_nodes, self.out_dim)
        out = self.out_proj(out)
        out = self.dropout(out)

        if self.residual:
            if self.residual_proj is not None:
                x = self.residual_proj(x)
            out = out + x

        out = self.norm(out)

        return out

    def _sparse_softmax(
        self,
        scores: torch.Tensor,
        indices: torch.Tensor,
        num_nodes: int,
    ) -> torch.Tensor:
        """
        Compute a softmax over sparse attention scores, normalizing each
        destination node's incoming edges independently (a segment softmax).

        Args:
            scores: Attention scores (num_edges, num_heads)
            indices: Destination node indices (num_edges,)
            num_nodes: Total number of nodes

        Returns:
            Normalized attention weights (num_edges, num_heads)
        """
        # Subtract each segment's max for numerical stability.
        max_scores = torch.zeros(
            num_nodes, scores.size(1), device=scores.device, dtype=scores.dtype
        )
        max_scores.scatter_reduce_(
            0,
            indices.view(-1, 1).expand_as(scores),
            scores,
            reduce="amax",
            include_self=False,
        )
        scores = scores - max_scores[indices]

        # Exponentiate and normalize by the per-segment sum.
        exp_scores = torch.exp(scores)
        sum_exp = torch.zeros(
            num_nodes, scores.size(1), device=scores.device, dtype=scores.dtype
        )
        sum_exp.scatter_add_(0, indices.view(-1, 1).expand_as(exp_scores), exp_scores)

        return exp_scores / (sum_exp[indices] + 1e-8)


class GATEncoder(nn.Module):
    """
    Multi-layer Graph Attention Network encoder.

    Processes station observations through multiple GAT layers to capture
    spatial dependencies between weather stations.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        output_dim: int = 256,
        num_layers: int = 3,
        num_heads: int = 8,
        dropout: float = 0.1,
        edge_dim: Optional[int] = None,
    ):
        """
        Initialize GAT encoder.

        Args:
            input_dim: Input feature dimension
            hidden_dim: Hidden dimension
            output_dim: Output dimension
            num_layers: Number of GAT layers
            num_heads: Number of attention heads
            dropout: Dropout probability
            edge_dim: Edge feature dimension (used by the first layer only)
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers

        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Stack of GATv2 layers: only the first layer consumes edge
        # features, and the last layer maps to output_dim.
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_dim = hidden_dim
            out_dim = output_dim if i == num_layers - 1 else hidden_dim

            self.layers.append(
                GATv2Layer(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    num_heads=num_heads,
                    dropout=dropout,
                    edge_dim=edge_dim if i == 0 else None,
                    residual=True,
                )
            )

        self.output_proj = nn.Linear(output_dim, output_dim)
        self.output_norm = nn.LayerNorm(output_dim)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode station features through GAT layers.

        Args:
            x: Node features of shape (num_nodes, input_dim)
            edge_index: Graph connectivity of shape (2, num_edges)
            edge_attr: Edge features of shape (num_edges, edge_dim)
            batch: Batch assignment of shape (num_nodes,). Unused here:
                a block-diagonal edge_index already keeps batched graphs
                separate during message passing.

        Returns:
            Encoded features of shape (num_nodes, output_dim)
        """
        h = self.input_proj(x)

        # Only the first layer receives edge features (matching __init__).
        for i, layer in enumerate(self.layers):
            h = layer(
                h,
                edge_index,
                edge_attr if i == 0 else None,
            )

        h = self.output_proj(h)
        h = self.output_norm(h)

        return h

    def forward_batched(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
        return_attention: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass over batched graphs.

        Multiple graphs are handled in a single batch by merging them into
        one block-diagonal graph; the batch tensor records which nodes
        belong to which graph.

        Args:
            x: Batched node features
            edge_index: Batched edge indices
            edge_attr: Batched edge attributes
            batch: Batch assignment tensor
            return_attention: Whether to return attention weights

        Returns:
            Encoded features and optionally attention weights. GATv2Layer
            does not currently retain its attention weights, so the second
            element is always None.
        """
        h = self.forward(x, edge_index, edge_attr, batch)

        # Attention weights are not stored by the layers yet; None is
        # returned regardless of return_attention.
        return h, None
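

if __name__ == "__main__":
    # Minimal smoke test (a sketch only): the station count, feature
    # widths, and random graph below are illustrative assumptions, not
    # values from the LILITH pipeline.
    torch.manual_seed(0)

    num_stations, input_dim, edge_dim = 10, 16, 3
    encoder = GATEncoder(
        input_dim=input_dim,
        hidden_dim=64,
        output_dim=64,
        num_layers=2,
        num_heads=4,
        dropout=0.1,
        edge_dim=edge_dim,
    )

    x = torch.randn(num_stations, input_dim)
    # Random directed edges (row 0 = source, row 1 = destination) with
    # per-edge features standing in for geographic attributes.
    edge_index = torch.randint(0, num_stations, (2, 40))
    edge_attr = torch.randn(edge_index.size(1), edge_dim)

    out, attn = encoder.forward_batched(x, edge_index, edge_attr)
    assert out.shape == (num_stations, 64) and attn is None
    print("encoded station features:", tuple(out.shape))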