# MLGraph-Bitcoin-GAD / models.py
# Author: thanhphxu — uploaded via huggingface_hub ("Upload folder using huggingface_hub"), commit db886e4 (verified).
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, GATv2Conv, BatchNorm
class ResidualGATBlock(nn.Module):
    """Attention convolution -> BatchNorm -> ReLU -> Dropout, plus a residual skip.

    Args:
        in_channels: Width of the incoming node features.
        hidden_channels: Per-head output width of the convolution.
        heads: Number of attention heads. The block's output width is
            ``hidden_channels * heads`` (the size given to BatchNorm).
        dropout: Forwarded to the convolution's ``dropout`` argument and also
            used as the rate of the post-activation ``nn.Dropout``.
        v2: Build a ``GATv2Conv`` when True, otherwise a ``GATConv``.

    When ``in_channels`` differs from the output width, a ``nn.Linear`` maps the
    input onto that width before the residual addition. Otherwise the input is
    added back unchanged and ``res_proj`` stays ``None``.
    """

    def __init__(self, in_channels, hidden_channels, heads=8, dropout=0.5, v2=False):
        super().__init__()
        conv_cls = GATv2Conv if v2 else GATConv
        width = hidden_channels * heads
        # Submodules are created in a fixed order so seeded initialization and
        # state_dict / parameters() ordering stay stable.
        self.conv = conv_cls(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.bn = BatchNorm(width)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        # Only project the skip path when the shapes could not otherwise be summed.
        self.res_proj = nn.Linear(in_channels, width) if in_channels != width else None

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(relu(bn(conv(x)))) + skip(x)``.

        ``x`` is presumably (num_nodes, in_channels) and ``edge_index`` the
        usual PyG (2, num_edges) connectivity tensor. TODO: confirm with callers.
        """
        h = self.conv(x, edge_index)
        h = self.dropout(self.act(self.bn(h)))
        shortcut = x if self.res_proj is None else self.res_proj(x)
        return h + shortcut
class GATBaseline(nn.Module):
    """Stack of residual GATConv blocks followed by a single-output GATConv head.

    Args:
        in_channels: Width of the raw node features.
        hidden_channels: Per-head width inside every residual block.
        heads: Attention heads per residual block. Each block emits
            ``hidden_channels * heads`` features.
        num_blocks: Number of ``ResidualGATBlock`` layers (v1 attention).
        dropout: Rate used inside the blocks, between blocks, and as the
            output convolution's ``dropout`` argument.

    The forward pass returns a 1-D tensor, flattened with ``view(-1)`` from the
    single-channel output convolution. This is presumably one score/logit per
    node. TODO: confirm against the training code.
    """

    def __init__(self, in_channels, hidden_channels=128, heads=8, num_blocks=2, dropout=0.5):
        super().__init__()
        width = hidden_channels * heads
        # The first block reads raw features; later blocks read the previous
        # block's concatenated-head output.
        self.blocks = nn.ModuleList(
            ResidualGATBlock(
                in_channels if idx == 0 else width,
                hidden_channels,
                heads=heads,
                dropout=dropout,
                v2=False,
            )
            for idx in range(num_blocks)
        )
        self.dropout = nn.Dropout(dropout)
        # With zero blocks the head sees the raw features directly.
        head_in = width if num_blocks > 0 else in_channels
        self.out_conv = GATConv(head_in, 1, heads=1, concat=False, dropout=dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Run every block, then the output convolution, and flatten the result."""
        h = x
        for blk in self.blocks:
            # NOTE: each block already ends in dropout; this applies dropout
            # again to the block output (residual included).
            h = self.dropout(blk(h, edge_index))
        return self.out_conv(h, edge_index).view(-1)
class GATv2Enhanced(nn.Module):
    """Stack of residual GATv2Conv blocks followed by a single-output GATv2Conv head.

    Same layout as the GATConv baseline, with every attention layer swapped
    for its GATv2 variant.

    Args:
        in_channels: Width of the raw node features.
        hidden_channels: Per-head width inside every residual block.
        heads: Attention heads per residual block. Each block emits
            ``hidden_channels * heads`` features.
        num_blocks: Number of ``ResidualGATBlock`` layers (built with ``v2=True``).
        dropout: Rate used inside the blocks, between blocks, and as the
            output convolution's ``dropout`` argument.

    The forward pass returns a 1-D tensor, flattened with ``view(-1)`` from the
    single-channel output convolution. This is presumably one score/logit per
    node. TODO: confirm against the training code.
    """

    def __init__(self, in_channels, hidden_channels=128, heads=8, num_blocks=2, dropout=0.5):
        super().__init__()
        block_width = hidden_channels * heads
        self.blocks = nn.ModuleList(
            ResidualGATBlock(
                in_channels if position == 0 else block_width,
                hidden_channels,
                heads=heads,
                dropout=dropout,
                v2=True,
            )
            for position in range(num_blocks)
        )
        self.dropout = nn.Dropout(dropout)
        # With zero blocks the head sees the raw features directly.
        final_in = in_channels if num_blocks <= 0 else block_width
        self.out_conv = GATv2Conv(final_in, 1, heads=1, concat=False, dropout=dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Run every block, then the output convolution, and flatten the result."""
        feats = x
        for layer in self.blocks:
            # NOTE: dropout is applied on top of the block's own internal dropout.
            feats = self.dropout(layer(feats, edge_index))
        logits = self.out_conv(feats, edge_index)
        return logits.view(-1)
class AdapterWrapper(nn.Module):
    """Feed inputs of a new feature width into a model built for another width.

    Args:
        in_dim_new: Feature width of the inputs that will be passed in.
        expected_in_dim: Feature width ``core_model`` was constructed for.
        core_model: Module called as ``core_model(x, edge_index)``.

    If the two widths differ, a learnable ``nn.Linear`` (with bias) maps the
    inputs to ``expected_in_dim`` first. If they match, ``adapter`` is ``None``
    and inputs pass straight through.
    """

    def __init__(self, in_dim_new, expected_in_dim, core_model):
        super().__init__()
        needs_projection = in_dim_new != expected_in_dim
        self.adapter = (
            nn.Linear(in_dim_new, expected_in_dim, bias=True) if needs_projection else None
        )
        self.core = core_model

    def forward(self, x, edge_index):
        """Optionally project ``x``, then delegate to the wrapped model."""
        projected = x if self.adapter is None else self.adapter(x)
        return self.core(projected, edge_index)