import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.ops import GCN, GraphUnet, Initializer, norm_g


class GNet(nn.Module):
    """Graph classifier: GCN embedding -> Graph U-Net -> pooled readout -> MLP.

    Each graph is embedded independently (``embed_one``), reduced to a flat
    vector by max/sum/mean readout, and the stacked batch of vectors is
    classified by a two-layer MLP that emits log-probabilities.
    """

    def __init__(self, in_dim, n_classes, args):
        """Build the network.

        Args:
            in_dim: width of the input node-feature vectors (input size of
                the first GCN).
            n_classes: number of output classes.
            args: config object. Reads ``act_n`` / ``act_c`` (names of
                ``torch.nn`` classes instantiated with no arguments),
                ``l_dim``, ``h_dim``, ``l_num``, ``ks``, ``drop_n`` and
                ``drop_c``.
        """
        super().__init__()
        # Activations are looked up by class name on torch.nn (e.g. "ReLU").
        self.n_act = getattr(nn, args.act_n)()
        self.c_act = getattr(nn, args.act_c)()
        self.s_gcn = GCN(in_dim, args.l_dim, self.n_act, args.drop_n)
        self.g_unet = GraphUnet(
            args.ks, args.l_dim, args.l_dim, args.l_dim, self.n_act, args.drop_n
        )
        # readout() concatenates a max, sum and mean pool of every feature map
        # returned by g_unet, hence the factor 3.
        # NOTE(review): assumes g_unet returns args.l_num + 1 feature maps,
        # each pooling to width l_dim — confirm against GraphUnet / config.
        self.out_l_1 = nn.Linear(3 * args.l_dim * (args.l_num + 1), args.h_dim)
        self.out_l_2 = nn.Linear(args.h_dim, n_classes)
        self.out_drop = nn.Dropout(p=args.drop_c)
        Initializer.weights_init(self)

    def forward(self, gs, hs, labels):
        """Embed and classify a batch of graphs, then score the predictions.

        Args:
            gs: sequence of per-graph structures, each passed to ``norm_g``
                (presumably adjacency matrices — confirm against utils.ops).
            hs: sequence of per-graph node-feature tensors, aligned with ``gs``.
            labels: target class indices, one per graph.

        Returns:
            ``(loss, acc)`` as produced by :meth:`metric`.
        """
        logits = self.classify(self.embed(gs, hs))
        return self.metric(logits, labels)

    def embed(self, gs, hs):
        """Embed every (graph, features) pair and stack the vectors on dim 0.

        Pairs are formed with ``zip``, so extra items in the longer of
        ``gs`` / ``hs`` are ignored.
        """
        return torch.stack([self.embed_one(g, h) for g, h in zip(gs, hs)], 0)

    def embed_one(self, g, h):
        """Embed a single graph into one flat readout vector.

        The graph is normalised once with ``norm_g`` and the same normalised
        graph is fed to both the initial GCN and the Graph U-Net.
        """
        g = norm_g(g)
        h = self.s_gcn(g, h)
        return self.readout(self.g_unet(g, h))

    def readout(self, hs):
        """Pool each feature map over dim 0 and concatenate the results.

        Output layout: all max-pools, then all sum-pools, then all
        mean-pools, each group in the order of ``hs``.
        """
        # torch.max(..., dim) returns (values, indices); keep the values.
        h_max = [torch.max(h, 0)[0] for h in hs]
        h_sum = [torch.sum(h, 0) for h in hs]
        h_mean = [torch.mean(h, 0) for h in hs]
        return torch.cat(h_max + h_sum + h_mean)

    def classify(self, h):
        """Map stacked graph embeddings to per-class log-probabilities.

        Dropout is applied before each linear layer. log-softmax is taken
        over dim 1, so ``h`` must have at least two dims; ``embed`` produces
        (n_graphs, features).
        """
        h = self.out_drop(h)
        h = self.c_act(self.out_l_1(h))
        h = self.out_drop(h)
        h = self.out_l_2(h)
        return F.log_softmax(h, dim=1)

    def metric(self, logits, labels):
        """Return ``(loss, acc)`` for log-probability ``logits``.

        ``loss`` is ``F.nll_loss`` with its default (mean) reduction. ``acc``
        is the fraction of rows whose argmax over dim 1 equals the label.
        """
        loss = F.nll_loss(logits, labels)
        _, preds = torch.max(logits, 1)
        acc = torch.mean((preds == labels).float())
        return loss, acc