import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, GATv2Conv, BatchNorm


class ResidualGATBlock(nn.Module):
    """Residual graph-attention block: conv -> BatchNorm -> ReLU -> dropout, plus skip.

    The attention layer is built with PyG's default ``concat=True``, so its
    output width is ``hidden_channels * heads``. When that width differs from
    ``in_channels``, the skip path goes through a learned ``nn.Linear``
    projection. Otherwise the input is added back unchanged.

    Args:
        in_channels: Width of the incoming node features.
        hidden_channels: Per-head output width of the attention layer.
        heads: Number of attention heads.
        dropout: Rate passed both to the attention layer's ``dropout``
            argument and to the ``nn.Dropout`` applied to the conv branch.
        v2: If True use ``GATv2Conv``, otherwise ``GATConv``.
    """

    def __init__(self, in_channels, hidden_channels, heads=8, dropout=0.5, v2=False):
        super().__init__()
        Conv = GATv2Conv if v2 else GATConv
        out_dim = hidden_channels * heads
        self.conv = Conv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        # Heads are concatenated, so normalise over hidden_channels * heads features.
        self.bn = BatchNorm(out_dim)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        # Only project the skip connection when the widths do not already match.
        self.res_proj = None
        if in_channels != out_dim:
            self.res_proj = nn.Linear(in_channels, out_dim)

    def forward(self, x, edge_index):
        """Return ``dropout(relu(bn(conv(x, edge_index)))) + skip(x)``.

        ``x`` is presumably a (num_nodes, in_channels) node-feature matrix and
        ``edge_index`` a PyG-style edge index; both are forwarded as-is to the
        attention layer.
        """
        identity = x
        out = self.conv(x, edge_index)
        out = self.bn(out)
        out = self.act(out)
        out = self.dropout(out)
        if self.res_proj is not None:
            identity = self.res_proj(identity)
        return out + identity


class _GATStack(nn.Module):
    """Shared implementation of ``GATBaseline`` and ``GATv2Enhanced``.

    Builds ``num_blocks`` ``ResidualGATBlock`` layers followed by a single-head,
    single-channel attention layer (``concat=False``). The ``v2`` flag selects
    ``GATv2Conv`` over ``GATConv`` for every attention layer, including the
    output layer. Submodules are registered as ``blocks``, ``dropout`` and
    ``out_conv``.
    """

    def __init__(self, in_channels, hidden_channels, heads, num_blocks, dropout, v2):
        super().__init__()
        OutConv = GATv2Conv if v2 else GATConv
        layers = []
        c_in = in_channels
        for _ in range(num_blocks):
            layers.append(
                ResidualGATBlock(c_in, hidden_channels, heads=heads, dropout=dropout, v2=v2)
            )
            # Every block emits concatenated heads, which feed the next block.
            c_in = hidden_channels * heads
        self.blocks = nn.ModuleList(layers)
        self.dropout = nn.Dropout(dropout)
        self.out_conv = OutConv(c_in, 1, heads=1, concat=False, dropout=dropout)

    def forward(self, x, edge_index):
        """Run all blocks, then the output layer; return a flat 1-D tensor.

        ``out_conv`` has a single output channel and a single head, so the
        flattened result holds one value per node.
        """
        for block in self.blocks:
            x = block(x, edge_index)
            # NOTE(review): each block already applies dropout to its conv
            # branch, so that branch is dropped out twice here; kept as-is.
            x = self.dropout(x)
        out = self.out_conv(x, edge_index)
        return out.view(-1)


class GATBaseline(_GATStack):
    """Residual GAT model built on ``GATConv``; outputs one value per node.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Per-head width inside each residual block.
        heads: Attention heads per residual block.
        num_blocks: Number of ``ResidualGATBlock`` layers before the output layer.
        dropout: Rate used for attention dropout and for feature dropout.
    """

    def __init__(self, in_channels, hidden_channels=128, heads=8, num_blocks=2, dropout=0.5):
        super().__init__(in_channels, hidden_channels, heads, num_blocks, dropout, v2=False)


class GATv2Enhanced(_GATStack):
    """Residual model identical to ``GATBaseline`` but built on ``GATv2Conv``.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Per-head width inside each residual block.
        heads: Attention heads per residual block.
        num_blocks: Number of ``ResidualGATBlock`` layers before the output layer.
        dropout: Rate used for attention dropout and for feature dropout.
    """

    def __init__(self, in_channels, hidden_channels=128, heads=8, num_blocks=2, dropout=0.5):
        super().__init__(in_channels, hidden_channels, heads, num_blocks, dropout, v2=True)


class AdapterWrapper(nn.Module):
    """Optionally remap input feature width before delegating to ``core_model``.

    If ``in_dim_new`` differs from ``expected_in_dim``, a biased ``nn.Linear``
    maps features from ``in_dim_new`` to ``expected_in_dim``. Otherwise
    ``adapter`` is ``None`` and features pass straight through. Presumably used
    to reuse a model built for a different input width — confirm with callers.

    Args:
        in_dim_new: Feature width of the data that will be fed in.
        expected_in_dim: Feature width ``core_model`` was constructed for.
        core_model: Module called as ``core_model(x, edge_index)``.
    """

    def __init__(self, in_dim_new, expected_in_dim, core_model):
        super().__init__()
        if in_dim_new != expected_in_dim:
            self.adapter = nn.Linear(in_dim_new, expected_in_dim, bias=True)
        else:
            self.adapter = None
        self.core = core_model

    def forward(self, x, edge_index):
        """Apply the adapter (if any) to ``x`` and return ``core(x, edge_index)``."""
        if self.adapter is not None:
            x = self.adapter(x)
        return self.core(x, edge_index)