# graphguard-backend / model.py
# Author: Bharateesha lvn
# Commit: Deploy GraphGuard Backend V1 (d511278)
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, global_mean_pool, global_max_pool
from torch.nn import Linear, Dropout
import config
class HeteroGAT(torch.nn.Module):
    """Graph-level classifier: two GATv2 layers, max+mean pooling, MLP head.

    Both attention layers receive edge attributes (``edge_dim`` is set to
    ``num_edge_types``). Node embeddings are pooled per graph with global
    max and global mean pooling, the two pooled vectors are concatenated,
    and a two-layer MLP produces ``out_channels`` raw logits.

    The submodule attribute names (``gat1``, ``gat2``, ``fc1``, ``fc2``,
    ``dropout``) determine the ``state_dict`` keys. Per the original
    comments, the layer layout and default sizes mirror the training code,
    so changing them will make existing checkpoints fail to load.
    """

    def __init__(
        self,
        hidden_dim: int = 256,
        out_channels: int = 5,
        num_edge_types: int = config.NUM_EDGE_TYPES,
        *,
        input_dim: int = 768,
        heads: int = 4,
        dropout: float = 0.5,
    ) -> None:
        """Build the layers.

        Args:
            hidden_dim: Per-head output width of ``gat1``, output width of
                ``gat2``, and hidden width of the classifier.
            out_channels: Number of logits produced by ``fc2``.
            num_edge_types: Passed to both GAT layers as ``edge_dim``, i.e.
                the width of each edge-attribute vector. The original code
                notes this is expected to be 6 -- confirm against
                ``config.NUM_EDGE_TYPES``.
            input_dim: Width of the input node features (keyword-only;
                defaults to the previously hard-coded 768).
            heads: Number of attention heads in ``gat1`` (keyword-only;
                defaults to the previously hard-coded 4).
            dropout: Drop probability of the shared dropout layer
                (keyword-only; defaults to the previously hard-coded 0.5).
        """
        super().__init__()

        # GAT layer 1: multi-head with concat=True, so its output width is
        # hidden_dim * heads.
        self.gat1 = GATv2Conv(
            input_dim,
            hidden_dim,
            heads=heads,
            edge_dim=num_edge_types,
            concat=True,
        )

        # GAT layer 2: consumes the concatenated heads of layer 1 and emits
        # hidden_dim features. The original author flags heads=1 here as a
        # deliberate fix to match the trained model; keep it as is.
        self.gat2 = GATv2Conv(
            hidden_dim * heads,
            hidden_dim,
            heads=1,
            edge_dim=num_edge_types,
            concat=False,
        )

        # Classifier head. fc1 takes hidden_dim * 2 inputs because forward()
        # concatenates the max-pooled and mean-pooled graph embeddings.
        self.fc1 = Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = Linear(hidden_dim, out_channels)

        # A single Dropout module is reused after each hidden stage.
        self.dropout = Dropout(p=dropout)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        """Compute raw class logits for every graph in the batch.

        Args:
            x: Node feature matrix; presumably ``(num_nodes, input_dim)``
                -- confirm against the caller.
            edge_index: Graph connectivity as consumed by ``GATv2Conv``.
            edge_attr: Per-edge attributes; presumably
                ``(num_edges, num_edge_types)`` -- confirm against the caller.
            batch: Node-to-graph assignment vector used by the global
                pooling operators.

        Returns:
            Raw, un-normalised logits with ``out_channels`` columns, one row
            per graph. No softmax is applied here; per the original comment,
            ``main.py`` handles it.
        """
        # NOTE: nn.Dropout is only an identity when the module is in eval
        # mode, so callers must call ``model.eval()`` before inference.

        # Attention layer 1.
        x = self.dropout(F.elu(self.gat1(x, edge_index, edge_attr=edge_attr)))

        # Attention layer 2.
        x = self.dropout(F.elu(self.gat2(x, edge_index, edge_attr=edge_attr)))

        # Graph-level readout: max-pooled features first, then mean-pooled,
        # concatenated along the feature dimension.
        pooled = torch.cat(
            [global_max_pool(x, batch), global_mean_pool(x, batch)],
            dim=1,
        )

        # MLP classifier.
        hidden = self.dropout(F.relu(self.fc1(pooled)))
        return self.fc2(hidden)