# Hugging Face Spaces page residue (Space status: "Sleeping") — kept as a comment.
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, GATv2Conv, BatchNorm
class ResidualGATBlock(nn.Module):
    """One GAT layer wrapped with batch norm, ReLU, dropout and a skip connection.

    Attention heads are concatenated, so the block emits
    ``hidden_channels * heads`` features per node.  When the input width
    differs from that, a linear projection aligns the residual branch.
    ``v2=True`` selects ``GATv2Conv`` instead of ``GATConv``.
    """

    def __init__(self, in_channels, hidden_channels, heads=8, dropout=0.5, v2=False):
        super().__init__()
        conv_cls = GATv2Conv if v2 else GATConv
        width = hidden_channels * heads
        self.conv = conv_cls(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.bn = BatchNorm(width)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        # Project the skip branch only when the widths disagree.
        self.res_proj = nn.Linear(in_channels, width) if in_channels != width else None

    def forward(self, x, edge_index):
        skip = x if self.res_proj is None else self.res_proj(x)
        h = self.dropout(self.act(self.bn(self.conv(x, edge_index))))
        return h + skip
class GATBaseline(nn.Module):
    """Stack of residual GATConv blocks followed by a scalar output head.

    Each block widens/keeps the representation at ``hidden_channels * heads``;
    the final single-head, non-concatenating ``GATConv`` produces one value
    per node, flattened to a 1-D tensor.
    """

    def __init__(self, in_channels, hidden_channels=128, heads=8, num_blocks=2, dropout=0.5):
        super().__init__()
        blocks = []
        dim = in_channels
        for _ in range(num_blocks):
            blocks.append(ResidualGATBlock(dim, hidden_channels, heads=heads, dropout=dropout, v2=False))
            dim = hidden_channels * heads
        self.blocks = nn.ModuleList(blocks)
        self.dropout = nn.Dropout(dropout)
        # heads=1, concat=False -> exactly one scalar per node.
        self.out_conv = GATConv(dim, 1, heads=1, concat=False, dropout=dropout)

    def forward(self, x, edge_index):
        h = x
        for blk in self.blocks:
            h = blk(h, edge_index)
        h = self.dropout(h)
        return self.out_conv(h, edge_index).view(-1)
class GATv2Enhanced(nn.Module):
    """Stack of residual GATv2Conv blocks followed by a scalar output head.

    Mirrors ``GATBaseline`` but builds every layer — blocks and the output
    head — from the v2 attention mechanism.  Returns one value per node as
    a 1-D tensor.
    """

    def __init__(self, in_channels, hidden_channels=128, heads=8, num_blocks=2, dropout=0.5):
        super().__init__()
        blocks = []
        dim = in_channels
        for _ in range(num_blocks):
            blocks.append(ResidualGATBlock(dim, hidden_channels, heads=heads, dropout=dropout, v2=True))
            dim = hidden_channels * heads
        self.blocks = nn.ModuleList(blocks)
        self.dropout = nn.Dropout(dropout)
        # heads=1, concat=False -> exactly one scalar per node.
        self.out_conv = GATv2Conv(dim, 1, heads=1, concat=False, dropout=dropout)

    def forward(self, x, edge_index):
        h = x
        for blk in self.blocks:
            h = blk(h, edge_index)
        h = self.dropout(h)
        return self.out_conv(h, edge_index).view(-1)
class AdapterWrapper(nn.Module):
    """Adapt inputs with a new feature width to a core model's expected width.

    If ``in_dim_new`` already equals ``expected_in_dim`` the input is passed
    straight through; otherwise a single ``nn.Linear`` maps it to the width
    the wrapped core model was built for.
    """

    def __init__(self, in_dim_new, expected_in_dim, core_model):
        super().__init__()
        needs_mapping = in_dim_new != expected_in_dim
        self.adapter = nn.Linear(in_dim_new, expected_in_dim, bias=True) if needs_mapping else None
        self.core = core_model

    def forward(self, x, edge_index):
        if self.adapter is None:
            return self.core(x, edge_index)
        return self.core(self.adapter(x), edge_index)