"""
Graph Attention Network (GATv2) Encoder for LILITH.

Learns spatial relationships between weather stations using
attention-based message passing on a geographic graph.
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class GATv2Layer(nn.Module):
    """
    Graph Attention Network v2 layer.

    Implements the improved attention mechanism from:
    "How Attentive are Graph Attention Networks?" (Brody et al., 2021)

    Key improvement: the attention vector is applied after the LeakyReLU
    nonlinearity (rather than before, as in GAT), which makes the attention
    dynamic and strictly more expressive.
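
    Example (illustrative sketch; dimensions are arbitrary):
        >>> layer = GATv2Layer(in_dim=16, out_dim=32, num_heads=4)
        >>> x = torch.randn(4, 16)
        >>> edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
        >>> layer(x, edge_index).shape
        torch.Size([4, 32])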
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_heads: int = 8,
        dropout: float = 0.1,
        edge_dim: Optional[int] = None,
        residual: bool = True,
        share_weights: bool = False,
    ):
        """
        Initialize GATv2 layer.

        Args:
            in_dim: Input feature dimension
            out_dim: Output feature dimension (per head)
            num_heads: Number of attention heads
            dropout: Dropout probability
            edge_dim: Edge feature dimension (optional)
            residual: Whether to use residual connection
            share_weights: Share weights between source and target transformations
        """
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.head_dim = out_dim // num_heads
        self.residual = residual

        assert out_dim % num_heads == 0, "out_dim must be divisible by num_heads"

        # Linear transformations for source and target nodes
        self.W_src = nn.Linear(in_dim, out_dim, bias=False)
        if share_weights:
            self.W_dst = self.W_src
        else:
            self.W_dst = nn.Linear(in_dim, out_dim, bias=False)

        # Attention parameters (one per head)
        self.attn = nn.Parameter(torch.empty(num_heads, self.head_dim))

        # Edge feature projection (optional)
        if edge_dim is not None:
            self.edge_proj = nn.Linear(edge_dim, out_dim, bias=False)
        else:
            self.edge_proj = None

        # Output projection
        self.out_proj = nn.Linear(out_dim, out_dim)

        # Layer norm and dropout
        self.norm = nn.LayerNorm(out_dim)
        self.dropout = nn.Dropout(dropout)
        self.attn_dropout = nn.Dropout(dropout)

        # Residual projection if dimensions don't match
        if residual and in_dim != out_dim:
            self.residual_proj = nn.Linear(in_dim, out_dim)
        else:
            self.residual_proj = None

        self._init_weights()

    def _init_weights(self):
        """Initialize weights."""
        nn.init.xavier_uniform_(self.W_src.weight)
        if self.W_dst is not self.W_src:
            nn.init.xavier_uniform_(self.W_dst.weight)
        nn.init.xavier_uniform_(self.attn)
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Node features of shape (num_nodes, in_dim)
            edge_index: Graph connectivity of shape (2, num_edges)
            edge_attr: Edge features of shape (num_edges, edge_dim)

        Returns:
            Updated node features of shape (num_nodes, out_dim)
        """
        num_nodes = x.size(0)
        src_idx, dst_idx = edge_index[0], edge_index[1]

        # Linear transformations
        h_src = self.W_src(x)  # (num_nodes, out_dim)
        h_dst = self.W_dst(x)  # (num_nodes, out_dim)

        # Reshape for multi-head attention
        h_src = h_src.view(num_nodes, self.num_heads, self.head_dim)
        h_dst = h_dst.view(num_nodes, self.num_heads, self.head_dim)

        # Get source and destination features for each edge
        h_src_edge = h_src[src_idx]  # (num_edges, num_heads, head_dim)
        h_dst_edge = h_dst[dst_idx]  # (num_edges, num_heads, head_dim)

        # GATv2 attention: e_ij = a^T LeakyReLU(W_src h_i + W_dst h_j)
        # (the attention vector a is applied after the nonlinearity, unlike GAT)
        attn_input = h_src_edge + h_dst_edge  # (num_edges, num_heads, head_dim)

        # Add edge features if available
        if edge_attr is not None and self.edge_proj is not None:
            edge_h = self.edge_proj(edge_attr)  # (num_edges, out_dim)
            edge_h = edge_h.view(-1, self.num_heads, self.head_dim)
            attn_input = attn_input + edge_h

        # Compute attention scores
        attn_input = F.leaky_relu(attn_input, negative_slope=0.2)
        attn_scores = (attn_input * self.attn).sum(dim=-1)  # (num_edges, num_heads)

        # Normalize attention scores using softmax over neighbors
        attn_scores = self._sparse_softmax(attn_scores, dst_idx, num_nodes)
        attn_scores = self.attn_dropout(attn_scores)

        # Aggregate messages
        # Weighted sum of source features
        messages = h_src_edge * attn_scores.unsqueeze(-1)  # (num_edges, num_heads, head_dim)

        # Scatter-add messages to destination nodes (dtype matches x so this
        # also works under mixed precision)
        out = torch.zeros(
            num_nodes, self.num_heads, self.head_dim,
            device=x.device, dtype=x.dtype,
        )
        out.scatter_add_(0, dst_idx.view(-1, 1, 1).expand_as(messages), messages)

        # Reshape and project
        out = out.view(num_nodes, self.out_dim)
        out = self.out_proj(out)
        out = self.dropout(out)

        # Residual connection
        if self.residual:
            if self.residual_proj is not None:
                x = self.residual_proj(x)
            out = out + x

        # Layer norm
        out = self.norm(out)

        return out

    def _sparse_softmax(
        self,
        scores: torch.Tensor,
        indices: torch.Tensor,
        num_nodes: int,
    ) -> torch.Tensor:
        """
        Compute softmax over sparse attention scores.

        Args:
            scores: Attention scores (num_edges, num_heads)
            indices: Destination node indices (num_edges,)
            num_nodes: Total number of nodes

        Returns:
            Normalized attention weights (num_edges, num_heads)
        """
        # Compute max for numerical stability
        max_scores = torch.zeros(
            num_nodes, scores.size(1), device=scores.device, dtype=scores.dtype
        )
        max_scores.scatter_reduce_(
            0,
            indices.view(-1, 1).expand_as(scores),
            scores,
            reduce="amax",
            include_self=False,
        )
        scores = scores - max_scores[indices]

        # Exp and sum
        exp_scores = torch.exp(scores)
        sum_exp = torch.zeros(
            num_nodes, scores.size(1), device=scores.device, dtype=scores.dtype
        )
        sum_exp.scatter_add_(0, indices.view(-1, 1).expand_as(exp_scores), exp_scores)

        # Normalize
        return exp_scores / (sum_exp[indices] + 1e-8)


class GATEncoder(nn.Module):
    """
    Multi-layer Graph Attention Network encoder.

    Processes station observations through multiple GAT layers to capture
    spatial dependencies between weather stations.
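
    Example (illustrative sketch; dimensions are arbitrary):
        >>> enc = GATEncoder(input_dim=12, hidden_dim=64, output_dim=64, num_layers=2)
        >>> x = torch.randn(10, 12)
        >>> edge_index = torch.randint(0, 10, (2, 40))
        >>> enc(x, edge_index).shape
        torch.Size([10, 64])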
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        output_dim: int = 256,
        num_layers: int = 3,
        num_heads: int = 8,
        dropout: float = 0.1,
        edge_dim: Optional[int] = None,
    ):
        """
        Initialize GAT encoder.

        Args:
            input_dim: Input feature dimension
            hidden_dim: Hidden dimension
            output_dim: Output dimension
            num_layers: Number of GAT layers
            num_heads: Number of attention heads
            dropout: Dropout probability
            edge_dim: Edge feature dimension
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers

        # Input projection
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # GAT layers
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_dim = hidden_dim
            out_dim = output_dim if i == num_layers - 1 else hidden_dim

            self.layers.append(
                GATv2Layer(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    num_heads=num_heads,
                    dropout=dropout,
                    edge_dim=edge_dim if i == 0 else None,  # Only use edge features in first layer
                    residual=True,
                )
            )

        # Output projection
        self.output_proj = nn.Linear(output_dim, output_dim)
        self.output_norm = nn.LayerNorm(output_dim)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode station features through GAT layers.

        Args:
            x: Node features of shape (num_nodes, input_dim)
            edge_index: Graph connectivity of shape (2, num_edges)
            edge_attr: Edge features of shape (num_edges, edge_dim)
            batch: Batch assignment of shape (num_nodes,) (currently unused;
                kept for API compatibility)

        Returns:
            Encoded features of shape (num_nodes, output_dim)
        """
        # Input projection
        h = self.input_proj(x)

        # Apply GAT layers
        for i, layer in enumerate(self.layers):
            h = layer(
                h,
                edge_index,
                edge_attr if i == 0 else None,
            )

        # Output projection
        h = self.output_proj(h)
        h = self.output_norm(h)

        return h

    def forward_batched(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
        return_attention: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass with batched graphs.

        This handles multiple graphs in a single batch by using
        the batch tensor to track which nodes belong to which graph.

        Args:
            x: Batched node features
            edge_index: Batched edge indices
            edge_attr: Batched edge attributes
            batch: Batch assignment tensor
            return_attention: Whether to return attention weights

        Returns:
            Encoded features and optionally attention weights
        """
        h = self.forward(x, edge_index, edge_attr, batch)

        # Per-edge attention weights are not currently stored by GATv2Layer,
        # so None is returned even when return_attention is True.
        return h, None
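

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the model): run the
    # encoder on a small random graph with random node and edge features. All
    # dimensions below are arbitrary and not taken from any LILITH config.
    torch.manual_seed(0)

    num_nodes, num_edges = 20, 60
    x = torch.randn(num_nodes, 16)                            # station features
    edge_index = torch.randint(0, num_nodes, (2, num_edges))  # random connectivity
    edge_attr = torch.randn(num_edges, 4)                     # e.g. distance features

    encoder = GATEncoder(
        input_dim=16,
        hidden_dim=64,
        output_dim=64,
        num_layers=2,
        num_heads=4,
        edge_dim=4,
    )
    out = encoder(x, edge_index, edge_attr)
    assert out.shape == (num_nodes, 64)
    print(f"encoded station features: {tuple(out.shape)}")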