"""
Graph Attention Network (GATv2) Encoder for LILITH.
Learns spatial relationships between weather stations using
attention-based message passing on a geographic graph.
"""
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class GATv2Layer(nn.Module):
"""
Graph Attention Network v2 layer.
Implements the improved attention mechanism from:
"How Attentive are Graph Attention Networks?" (Brody et al., 2021)
Key improvement: applies attention after the linear transformation,
allowing the attention function to be a universal approximator.
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_heads: int = 8,
dropout: float = 0.1,
edge_dim: Optional[int] = None,
residual: bool = True,
share_weights: bool = False,
):
"""
Initialize GATv2 layer.
Args:
in_dim: Input feature dimension
            out_dim: Total output feature dimension (split evenly across heads)
num_heads: Number of attention heads
dropout: Dropout probability
edge_dim: Edge feature dimension (optional)
residual: Whether to use residual connection
share_weights: Share weights between source and target transformations
"""
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.head_dim = out_dim // num_heads
self.residual = residual
assert out_dim % num_heads == 0, "out_dim must be divisible by num_heads"
# Linear transformations for source and target nodes
self.W_src = nn.Linear(in_dim, out_dim, bias=False)
if share_weights:
self.W_dst = self.W_src
else:
self.W_dst = nn.Linear(in_dim, out_dim, bias=False)
# Attention parameters (one per head)
self.attn = nn.Parameter(torch.empty(num_heads, self.head_dim))
# Edge feature projection (optional)
if edge_dim is not None:
self.edge_proj = nn.Linear(edge_dim, out_dim, bias=False)
else:
self.edge_proj = None
# Output projection
self.out_proj = nn.Linear(out_dim, out_dim)
# Layer norm and dropout
self.norm = nn.LayerNorm(out_dim)
self.dropout = nn.Dropout(dropout)
self.attn_dropout = nn.Dropout(dropout)
# Residual projection if dimensions don't match
if residual and in_dim != out_dim:
self.residual_proj = nn.Linear(in_dim, out_dim)
else:
self.residual_proj = None
self._init_weights()
def _init_weights(self):
"""Initialize weights."""
nn.init.xavier_uniform_(self.W_src.weight)
if self.W_dst is not self.W_src:
nn.init.xavier_uniform_(self.W_dst.weight)
nn.init.xavier_uniform_(self.attn)
nn.init.xavier_uniform_(self.out_proj.weight)
nn.init.zeros_(self.out_proj.bias)
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
edge_attr: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass.
Args:
x: Node features of shape (num_nodes, in_dim)
edge_index: Graph connectivity of shape (2, num_edges)
edge_attr: Edge features of shape (num_edges, edge_dim)
Returns:
Updated node features of shape (num_nodes, out_dim)
"""
num_nodes = x.size(0)
src_idx, dst_idx = edge_index[0], edge_index[1]
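        # Convention (following the usual PyTorch Geometric layout):
        # edge_index[0] holds senders and edge_index[1] holds receivers,
        # so messages flow src -> dst.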
# Linear transformations
h_src = self.W_src(x) # (num_nodes, out_dim)
h_dst = self.W_dst(x) # (num_nodes, out_dim)
# Reshape for multi-head attention
h_src = h_src.view(num_nodes, self.num_heads, self.head_dim)
h_dst = h_dst.view(num_nodes, self.num_heads, self.head_dim)
# Get source and destination features for each edge
h_src_edge = h_src[src_idx] # (num_edges, num_heads, head_dim)
h_dst_edge = h_dst[dst_idx] # (num_edges, num_heads, head_dim)
        # GATv2 attention: e_ij = a^T LeakyReLU(W_dst h_i + W_src h_j),
        # i.e. the LeakyReLU is applied before the dot product with a
        # (see the scoring comparison in the class body above).
attn_input = h_src_edge + h_dst_edge # (num_edges, num_heads, head_dim)
# Add edge features if available
if edge_attr is not None and self.edge_proj is not None:
edge_h = self.edge_proj(edge_attr) # (num_edges, out_dim)
edge_h = edge_h.view(-1, self.num_heads, self.head_dim)
attn_input = attn_input + edge_h
# Compute attention scores
attn_input = F.leaky_relu(attn_input, negative_slope=0.2)
attn_scores = (attn_input * self.attn).sum(dim=-1) # (num_edges, num_heads)
# Normalize attention scores using softmax over neighbors
attn_scores = self._sparse_softmax(attn_scores, dst_idx, num_nodes)
attn_scores = self.attn_dropout(attn_scores)
# Aggregate messages
# Weighted sum of source features
messages = h_src_edge * attn_scores.unsqueeze(-1) # (num_edges, num_heads, head_dim)
# Scatter-add messages to destination nodes
        out = torch.zeros(
            num_nodes, self.num_heads, self.head_dim,
            device=messages.device, dtype=messages.dtype,
        )
out.scatter_add_(0, dst_idx.view(-1, 1, 1).expand_as(messages), messages)
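        # Net effect per head: out_i = sum over incoming edges j->i of
        # alpha_ij * (W_src h_j), an attention-weighted neighborhood average
        # (the alphas over each node's in-edges sum to 1 after the softmax).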
# Reshape and project
out = out.view(num_nodes, self.out_dim)
out = self.out_proj(out)
out = self.dropout(out)
# Residual connection
if self.residual:
if self.residual_proj is not None:
x = self.residual_proj(x)
out = out + x
# Layer norm
out = self.norm(out)
return out
def _sparse_softmax(
self,
scores: torch.Tensor,
indices: torch.Tensor,
num_nodes: int,
) -> torch.Tensor:
"""
Compute softmax over sparse attention scores.
Args:
scores: Attention scores (num_edges, num_heads)
indices: Destination node indices (num_edges,)
num_nodes: Total number of nodes
Returns:
Normalized attention weights (num_edges, num_heads)
"""
# Compute max for numerical stability
        max_scores = torch.zeros(
            num_nodes, scores.size(1), device=scores.device, dtype=scores.dtype
        )
max_scores.scatter_reduce_(
0,
indices.view(-1, 1).expand_as(scores),
scores,
reduce="amax",
include_self=False,
)
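        # Subtracting each edge's per-destination max is the standard
        # log-sum-exp stabilization: exp() below cannot overflow.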
scores = scores - max_scores[indices]
# Exp and sum
exp_scores = torch.exp(scores)
        sum_exp = torch.zeros(
            num_nodes, scores.size(1), device=scores.device, dtype=scores.dtype
        )
sum_exp.scatter_add_(0, indices.view(-1, 1).expand_as(exp_scores), exp_scores)
# Normalize
return exp_scores / (sum_exp[indices] + 1e-8)
class GATEncoder(nn.Module):
"""
Multi-layer Graph Attention Network encoder.
Processes station observations through multiple GAT layers to capture
spatial dependencies between weather stations.
"""
def __init__(
self,
input_dim: int,
hidden_dim: int = 256,
output_dim: int = 256,
num_layers: int = 3,
num_heads: int = 8,
dropout: float = 0.1,
edge_dim: Optional[int] = None,
):
"""
Initialize GAT encoder.
Args:
input_dim: Input feature dimension
hidden_dim: Hidden dimension
output_dim: Output dimension
num_layers: Number of GAT layers
num_heads: Number of attention heads
dropout: Dropout probability
edge_dim: Edge feature dimension
"""
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.output_dim = output_dim
self.num_layers = num_layers
# Input projection
self.input_proj = nn.Linear(input_dim, hidden_dim)
# GAT layers
self.layers = nn.ModuleList()
for i in range(num_layers):
in_dim = hidden_dim
out_dim = output_dim if i == num_layers - 1 else hidden_dim
self.layers.append(
GATv2Layer(
in_dim=in_dim,
out_dim=out_dim,
num_heads=num_heads,
dropout=dropout,
edge_dim=edge_dim if i == 0 else None, # Only use edge features in first layer
residual=True,
)
)
# Output projection
self.output_proj = nn.Linear(output_dim, output_dim)
self.output_norm = nn.LayerNorm(output_dim)
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
edge_attr: Optional[torch.Tensor] = None,
batch: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Encode station features through GAT layers.
Args:
x: Node features of shape (num_nodes, input_dim)
edge_index: Graph connectivity of shape (2, num_edges)
edge_attr: Edge features of shape (num_edges, edge_dim)
            batch: Batch assignment of shape (num_nodes,); unused here, since
                graphs batched as a disjoint union stay independent without
                per-graph bookkeeping
Returns:
Encoded features of shape (num_nodes, output_dim)
"""
# Input projection
h = self.input_proj(x)
# Apply GAT layers
for i, layer in enumerate(self.layers):
h = layer(
h,
edge_index,
edge_attr if i == 0 else None,
)
# Output projection
h = self.output_proj(h)
h = self.output_norm(h)
return h
def forward_batched(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
edge_attr: Optional[torch.Tensor] = None,
batch: Optional[torch.Tensor] = None,
return_attention: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
        Forward pass for batched graphs.
        Graphs batched as a disjoint union share no edges, so the standard
        forward pass already keeps them independent; the batch tensor is
        accepted for interface compatibility.
Args:
x: Batched node features
edge_index: Batched edge indices
edge_attr: Batched edge attributes
batch: Batch assignment tensor
return_attention: Whether to return attention weights
Returns:
Encoded features and optionally attention weights
"""
        h = self.forward(x, edge_index, edge_attr, batch)
        # GATv2Layer does not currently expose its attention weights, so the
        # second element is always None regardless of return_attention.
        return h, None
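

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): run the encoder on a small
    # synthetic station graph with random features. All sizes below are
    # arbitrary choices for the demo, not values prescribed by LILITH.
    torch.manual_seed(0)
    num_nodes, input_dim, edge_dim = 12, 16, 3
    # Fully connected directed graph without self-loops.
    src, dst = torch.meshgrid(
        torch.arange(num_nodes), torch.arange(num_nodes), indexing="ij"
    )
    mask = src != dst
    edge_index = torch.stack([src[mask], dst[mask]])
    x = torch.randn(num_nodes, input_dim)
    edge_attr = torch.randn(edge_index.size(1), edge_dim)

    encoder = GATEncoder(
        input_dim=input_dim,
        hidden_dim=64,
        output_dim=64,
        num_layers=2,
        num_heads=4,
        edge_dim=edge_dim,
    )
    out = encoder(x, edge_index, edge_attr)
    print(out.shape)  # expected: torch.Size([12, 64])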