| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from utils.ops import GCN, GraphUnet, Initializer, norm_g | |
class GNet(nn.Module):
    """Graph-level classifier: GCN encoder, Graph U-Net, pooled readout, MLP.

    Each graph is encoded on its own (see ``embed_one``), the per-graph
    vectors are stacked into a batch, and a two-layer MLP with dropout turns
    them into log-probabilities over ``n_classes``.

    ``args`` is read for: ``act_n`` / ``act_c`` (names of ``torch.nn``
    classes instantiated with no arguments), ``l_dim``, ``h_dim``,
    ``l_num``, ``ks``, ``drop_n`` and ``drop_c``.
    """

    def __init__(self, in_dim, n_classes, args):
        super().__init__()
        # Look the activation classes up by name on torch.nn, e.g. "ReLU".
        self.n_act = getattr(nn, args.act_n)()
        self.c_act = getattr(nn, args.act_c)()

        latent = args.l_dim
        # Construction order is deliberately unchanged: it fixes submodule
        # registration order (and so state_dict key order).
        self.s_gcn = GCN(in_dim, latent, self.n_act, args.drop_n)
        self.g_unet = GraphUnet(
            args.ks, latent, latent, latent, self.n_act, args.drop_n,
        )
        # readout() emits max, sum and mean (3 groups) of l_dim features for
        # every tensor produced by g_unet.
        # NOTE(review): assumes GraphUnet returns exactly args.l_num + 1
        # tensors of width l_dim -- confirm against utils.ops.GraphUnet.
        readout_dim = 3 * latent * (args.l_num + 1)
        self.out_l_1 = nn.Linear(readout_dim, args.h_dim)
        self.out_l_2 = nn.Linear(args.h_dim, n_classes)
        self.out_drop = nn.Dropout(p=args.drop_c)
        Initializer.weights_init(self)

    def forward(self, gs, hs, labels):
        """Embed, classify and score one batch of graphs.

        Args:
            gs: per-sample graph inputs, each passed through ``norm_g``.
            hs: per-sample node-feature tensors, paired with ``gs``.
            labels: class-index targets for the batch.

        Returns:
            ``(loss, acc)`` exactly as produced by :meth:`metric`.
        """
        batch_log_probs = self.classify(self.embed(gs, hs))
        return self.metric(batch_log_probs, labels)

    def embed(self, gs, hs):
        """Embed every ``(graph, features)`` pair and stack them on dim 0.

        Pairs are formed with ``zip``, so surplus items in the longer input
        are ignored; ``torch.stack`` rejects an empty list of pairs.
        """
        per_graph = [
            self.embed_one(graph, feats) for graph, feats in zip(gs, hs)
        ]
        return torch.stack(per_graph, 0)

    def embed_one(self, g, h):
        """Encode a single graph into one flat readout vector."""
        # NOTE(review): norm_g presumably normalises the graph/adjacency
        # input -- confirm in utils.ops. It is computed once and shared by
        # both the GCN and the U-Net.
        g_norm = norm_g(g)
        node_feats = self.s_gcn(g_norm, h)
        level_feats = self.g_unet(g_norm, node_feats)
        return self.readout(level_feats)

    def readout(self, hs):
        """Pool every tensor of ``hs`` over dim 0 and concatenate.

        Layout of the result: all per-tensor maxima, then all sums, then
        all means (each group in ``hs`` order), joined along dim 0.
        """
        reducers = (
            lambda t: t.max(dim=0).values,
            lambda t: t.sum(dim=0),
            lambda t: t.mean(dim=0),
        )
        pooled = []
        # One full pass over hs per statistic, so each group stays contiguous.
        for reduce in reducers:
            pooled.extend(reduce(level) for level in hs)
        return torch.cat(pooled)

    def classify(self, h):
        """Map a batch of readout vectors to log-probabilities over dim 1.

        Dropout is applied to the input and again after the hidden
        activation.
        """
        hidden = self.c_act(self.out_l_1(self.out_drop(h)))
        scores = self.out_l_2(self.out_drop(hidden))
        return F.log_softmax(scores, dim=1)

    def metric(self, logits, labels):
        """Return ``(loss, acc)``: mean NLL loss and top-1 accuracy.

        ``logits`` is fed straight to ``F.nll_loss``, so it must already
        hold log-probabilities (as :meth:`classify` returns).
        """
        loss = F.nll_loss(logits, labels)
        predicted = logits.max(dim=1).indices
        hits = (predicted == labels).float()
        return loss, hits.mean()