File size: 2,963 Bytes
db886e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, GATv2Conv, BatchNorm

class ResidualGATBlock(nn.Module):
    """One graph-attention layer with batch norm, ReLU, dropout, and a skip path.

    The attention output concatenates ``heads`` heads, so the block's output
    width is ``hidden_channels * heads``. When the input width differs from
    that, the skip connection is linearly projected to match; otherwise the
    raw input is added directly.
    """

    def __init__(self, in_channels, hidden_channels, heads=8, dropout=0.5, v2=False):
        super().__init__()
        conv_cls = GATv2Conv if v2 else GATConv
        self.conv = conv_cls(in_channels, hidden_channels, heads=heads, dropout=dropout)
        out_dim = hidden_channels * heads
        self.bn = BatchNorm(out_dim)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        # Project the residual branch only when the widths disagree.
        self.res_proj = nn.Linear(in_channels, out_dim) if in_channels != out_dim else None

    def forward(self, x, edge_index):
        skip = x if self.res_proj is None else self.res_proj(x)
        h = self.conv(x, edge_index)
        h = self.dropout(self.act(self.bn(h)))
        return h + skip

class GATBaseline(nn.Module):
    """Stack of residual GATConv blocks followed by a single-head GAT readout.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Per-head hidden width inside each residual block.
        heads: Number of attention heads per block (block output width is
            ``hidden_channels * heads``).
        num_blocks: How many ``ResidualGATBlock`` layers to stack.
        dropout: Dropout probability used throughout.
        out_channels: Width of the per-node prediction. Defaults to 1, in
            which case the output is flattened to a 1-D tensor exactly as
            before; larger values keep shape ``(N, out_channels)``.
    """

    def __init__(self, in_channels, hidden_channels=128, heads=8, num_blocks=2,
                 dropout=0.5, out_channels=1):
        super().__init__()
        layers = []
        c_in = in_channels
        for _ in range(num_blocks):
            layers.append(ResidualGATBlock(c_in, hidden_channels, heads=heads,
                                           dropout=dropout, v2=False))
            c_in = hidden_channels * heads  # each block concatenates its heads
        self.blocks = nn.ModuleList(layers)
        self.dropout = nn.Dropout(dropout)
        self.out_channels = out_channels
        # Readout: single head, averaged (concat=False), so width == out_channels.
        self.out_conv = GATConv(c_in, out_channels, heads=1, concat=False, dropout=dropout)

    def forward(self, x, edge_index):
        """Return per-node predictions: shape (N,) when ``out_channels == 1``,
        otherwise (N, out_channels)."""
        for block in self.blocks:
            x = block(x, edge_index)
        x = self.dropout(x)
        out = self.out_conv(x, edge_index)
        # Preserve the original 1-D output contract for scalar targets.
        return out.view(-1) if self.out_channels == 1 else out

class GATv2Enhanced(nn.Module):
    """Stack of residual GATv2Conv blocks followed by a single-head GATv2 readout.

    Mirrors ``GATBaseline`` but uses the GATv2 attention mechanism in every
    layer (``v2=True`` blocks and a ``GATv2Conv`` output layer).

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Per-head hidden width inside each residual block.
        heads: Number of attention heads per block (block output width is
            ``hidden_channels * heads``).
        num_blocks: How many ``ResidualGATBlock`` layers to stack.
        dropout: Dropout probability used throughout.
        out_channels: Width of the per-node prediction. Defaults to 1, in
            which case the output is flattened to a 1-D tensor exactly as
            before; larger values keep shape ``(N, out_channels)``.
    """

    def __init__(self, in_channels, hidden_channels=128, heads=8, num_blocks=2,
                 dropout=0.5, out_channels=1):
        super().__init__()
        layers = []
        c_in = in_channels
        for _ in range(num_blocks):
            layers.append(ResidualGATBlock(c_in, hidden_channels, heads=heads,
                                           dropout=dropout, v2=True))
            c_in = hidden_channels * heads  # each block concatenates its heads
        self.blocks = nn.ModuleList(layers)
        self.dropout = nn.Dropout(dropout)
        self.out_channels = out_channels
        # Readout: single head, averaged (concat=False), so width == out_channels.
        self.out_conv = GATv2Conv(c_in, out_channels, heads=1, concat=False, dropout=dropout)

    def forward(self, x, edge_index):
        """Return per-node predictions: shape (N,) when ``out_channels == 1``,
        otherwise (N, out_channels)."""
        for block in self.blocks:
            x = block(x, edge_index)
        x = self.dropout(x)
        out = self.out_conv(x, edge_index)
        # Preserve the original 1-D output contract for scalar targets.
        return out.view(-1) if self.out_channels == 1 else out

class AdapterWrapper(nn.Module):
    """Prepend a linear adapter so a model built for ``expected_in_dim``
    input features can consume inputs of width ``in_dim_new``.

    When the two widths already match, inputs pass straight through to the
    wrapped model with no adapter layer at all.
    """

    def __init__(self, in_dim_new, expected_in_dim, core_model):
        super().__init__()
        needs_adapter = in_dim_new != expected_in_dim
        self.adapter = nn.Linear(in_dim_new, expected_in_dim, bias=True) if needs_adapter else None
        self.core = core_model

    def forward(self, x, edge_index):
        if self.adapter is None:
            return self.core(x, edge_index)
        return self.core(self.adapter(x), edge_index)