Spaces:
Runtime error
Runtime error
| from turtle import forward | |
| import dgl | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import dgl.function as fn | |
| from dgl.nn.functional import edge_softmax | |
| from FairGNN.src.models.GCN import GCN | |
| from RHGN.layers import * | |
| from RHGN.layers import RHGNLayer | |
class RHGN_adv(nn.Module):
    """Sensitive-attribute estimator plus adversary heads for an RHGN backbone.

    ``sens_model`` is a FairGNN ``GCN`` applied to a graph and node features.
    ``adv_model`` is a linear layer mapping a hidden representation of width
    ``n_hid`` to a single output.

    The frozen ``cid*_feature`` embedding tables are not read by ``forward``.
    They remain registered submodules so the module's state-dict keys stay
    the same.

    Args:
        G, node_dict, edge_dict, n_inp, n_out, n_layers, n_heads: unused;
            accepted so the signature mirrors the RHGN model constructors.
        n_hid: width of the hidden representation fed to ``adv_model``.
        cid1_feature, cid2_feature, cid3_feature: 2-D embedding tables,
            wrapped as frozen ``nn.Embedding`` modules.
    """

    def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads,
                 cid1_feature, cid2_feature, cid3_feature):
        super().__init__()
        # Frozen copies of the supplied category embedding tables. Creation
        # order is kept as-is so seeded weight initialisation is unchanged.
        self.cid1_feature = nn.Embedding(cid1_feature.size(0), cid1_feature.size(1))
        self.cid1_feature.weight = nn.Parameter(cid1_feature)
        self.cid1_feature.weight.requires_grad = False
        self.cid2_feature = nn.Embedding(cid2_feature.size(0), cid2_feature.size(1))
        self.cid2_feature.weight = nn.Parameter(cid2_feature)
        self.cid2_feature.weight.requires_grad = False
        self.cid3_feature = nn.Embedding(cid3_feature.size(0), cid3_feature.size(1))
        self.cid3_feature.weight = nn.Parameter(cid3_feature)
        self.cid3_feature.weight.requires_grad = False
        # Adversary over the backbone's hidden representation.
        self.adv_model = nn.Linear(n_hid, 1)
        # NOTE(review): sizes are hard-coded; presumably (in_features, hidden,
        # out, dropout). Confirm against FairGNN's GCN and that the features
        # passed in forward are 200-wide.
        self.sens_model = GCN(200, 128, 1, 0.5)

    def forward(self, h, inputs, G, blocks, out_key, label_key, is_train=True, print_flag=False):
        """Run the sensitive-attribute estimator and the adversary.

        Args:
            h: hidden representation from the main model; passed to
                ``adv_model``.
            inputs: indexable; ``inputs[0]`` is passed to ``sens_model``
                together with ``G``.
            G: graph passed to ``sens_model``.
            blocks, out_key, label_key, is_train: unused; accepted for
                interface parity with the RHGN models.
            print_flag: when True, print the graph and tensor shapes for
                debugging.

        Returns:
            ``(s, s_g)``: the ``sens_model`` output and the ``adv_model``
            output.
        """
        s = self.sens_model(G, inputs[0])
        s_g = self.adv_model(h)
        if print_flag:
            print('graph:', G)
            # ``inputs`` may be a plain sequence without ``.shape``.
            print('inputs:', getattr(inputs, 'shape', type(inputs).__name__))
            print('s:', s.shape)
            print('s_g:', s_g.shape)
        return s, s_g
class ali_RHGN(nn.Module):
    """RHGN model whose item inputs are fused with category-ID embeddings.

    ``forward`` does the following:
      1. looks up frozen category embeddings for every source item;
      2. attends from the item's input features (query) over those
         embeddings (keys/values), then blends the result back into the
         item features through a learned sigmoid gate;
      3. projects item and user features to ``n_hid`` and runs
         ``n_layers`` ``RHGNLayer`` modules;
      4. applies a linear output head to the ``out_key`` representations.

    Args:
        G: unused; kept for interface compatibility.
        node_dict: node-type name -> index into ``adapt_ws``; also passed
            to ``RHGNLayer``.
        edge_dict: passed to ``RHGNLayer``.
        n_inp: width of the per-type input projection. The residual blend
            in ``forward`` adds an ``n_inp``-wide tensor to the 200-wide
            item features, so this must be 200 for the model to run.
        n_hid: hidden width of the RHGN layers.
        n_out: output width of the final linear head.
        n_layers: number of ``RHGNLayer`` modules.
        n_heads: heads per ``RHGNLayer``.
        cid1_feature, cid2_feature, cid3_feature: 2-D embedding tables,
            wrapped as frozen ``nn.Embedding`` modules.
        use_norm: forwarded to ``RHGNLayer``.
    """

    def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads,
                 cid1_feature, cid2_feature, cid3_feature, use_norm=True):
        super().__init__()
        self.node_dict = node_dict
        self.edge_dict = edge_dict
        self.gcs = nn.ModuleList()
        self.n_inp = n_inp
        self.n_hid = n_hid
        self.n_out = n_out
        self.n_layers = n_layers
        # NOTE: submodules are created in a fixed order so seeded weight
        # initialisation and parameter order (optimizer state) stay stable.
        # One input projection per node type, indexed by node_dict[type].
        self.adapt_ws = nn.ModuleList()
        for _ in range(len(node_dict)):
            self.adapt_ws.append(nn.Linear(n_inp, n_hid))
        for _ in range(n_layers):
            self.gcs.append(RHGNLayer(n_hid, n_hid, node_dict, edge_dict, n_heads, use_norm=use_norm))
        self.out = nn.Linear(n_hid, n_out)
        # Frozen category embedding tables.
        self.cid1_feature = nn.Embedding(cid1_feature.size(0), cid1_feature.size(1))
        self.cid1_feature.weight = nn.Parameter(cid1_feature)
        self.cid1_feature.weight.requires_grad = False
        self.cid2_feature = nn.Embedding(cid2_feature.size(0), cid2_feature.size(1))
        self.cid2_feature.weight = nn.Parameter(cid2_feature)
        self.cid2_feature.weight.requires_grad = False
        self.cid3_feature = nn.Embedding(cid3_feature.size(0), cid3_feature.size(1))
        self.cid3_feature.weight = nn.Parameter(cid3_feature)
        self.cid3_feature.weight.requires_grad = False
        # Not used by forward; kept so state-dict keys are unchanged.
        self.excitation = nn.Sequential(
            nn.Linear(3, 32, bias=False),
            nn.ReLU(),
            nn.Linear(32, 3, bias=False),
            nn.ReLU(),
        )
        # Attention projections: query from item features, keys/values from
        # the 200-wide category embeddings.
        self.query = nn.Linear(200, n_inp)
        self.key = nn.Linear(200, n_inp)
        self.value = nn.Linear(200, n_inp)
        # Gate logit; sigmoid(skip) weights attention output vs. raw features.
        self.skip = nn.Parameter(torch.ones(1))

    def forward(self, input_nodes, output_nodes, blocks, out_key, label_key, is_train=True, print_flag=False):
        """Compute outputs and labels for the destination nodes of the last block.

        Args:
            input_nodes, output_nodes: unused.
            blocks: per-layer blocks; ``blocks[0]`` must carry ``cid1`` and
                ``inp`` on its ``item`` source nodes and ``inp`` on its
                ``user`` source nodes.
            out_key: node type whose representations are returned.
            label_key: label field on ``blocks[-1]``'s ``out_key``
                destination nodes.
            is_train, print_flag: forwarded to every ``RHGNLayer``.

        Returns:
            ``(outputs, labels)``: ``self.out`` applied to the final
            ``out_key`` representations, and the matching labels.
        """
        item_data = blocks[0].srcnodes['item'].data
        cid1_emb = self.cid1_feature(item_data['cid1'].unsqueeze(1))  # (N, 1, 200)
        # NOTE(review): the cid1 embedding fills all three attention slots;
        # the cid2/cid3 tables are not consulted. With identical keys the
        # softmax weights are uniform (1/3). Confirm this ablation is intended.
        cat_emb = torch.cat((cid1_emb, cid1_emb, cid1_emb), 1)  # (N, 3, 200)

        item_feature = item_data['inp']  # (N, 200)
        user_feature = blocks[0].srcnodes['user'].data['inp']

        k = self.key(cat_emb)  # (N, 3, n_inp)
        v = self.value(cat_emb)  # (N, 3, n_inp)
        q = self.query(item_feature.unsqueeze(-2))  # (N, 1, n_inp)
        # Scaled dot-product scores of the item query against each slot.
        att_score = torch.einsum("bij,bjk->bik", k, q.transpose(1, 2)) / math.sqrt(self.n_inp)  # (N, 3, 1)
        att_score = torch.softmax(att_score, dim=1)
        alpha = torch.sigmoid(self.skip)  # (1,)
        attended = torch.mean(v * att_score, dim=-2)  # (N, n_inp)
        item_feature = alpha * attended + (1 - alpha) * item_feature

        h = {
            'item': F.gelu(self.adapt_ws[self.node_dict['item']](item_feature)),
            'user': F.gelu(self.adapt_ws[self.node_dict['user']](user_feature)),
        }
        for i in range(self.n_layers):
            h = self.gcs[i](blocks[i], h, is_train=is_train, print_flag=print_flag)
        outputs = self.out(h[out_key])
        labels = blocks[-1].dstnodes[out_key].data[label_key]
        return outputs, labels
class jd_RHGN(nn.Module):
    """RHGN model whose item inputs are fused with category-ID embeddings.

    ``forward`` does the following:
      1. looks up frozen category embeddings for every source item;
      2. attends from the item's input features (query) over those
         embeddings (keys/values), then blends the result back into the
         item features through a learned sigmoid gate;
      3. projects item and user features to ``n_hid`` and runs
         ``n_layers`` ``RHGNLayer`` modules;
      4. applies a linear output head to the ``out_key`` representations.

    Args:
        G: unused; kept for interface compatibility.
        node_dict: node-type name -> index into ``adapt_ws``; also passed
            to ``RHGNLayer``.
        edge_dict: passed to ``RHGNLayer``.
        n_inp: width of the per-type input projection. The residual blend
            in ``forward`` adds an ``n_inp``-wide tensor to the 200-wide
            item features, so this must be 200 for the model to run.
        n_hid: hidden width of the RHGN layers.
        n_out: output width of the final linear head.
        n_layers: number of ``RHGNLayer`` modules.
        n_heads: heads per ``RHGNLayer``.
        cid1_feature, cid2_feature, cid3_feature, cid4_feature: 2-D
            embedding tables, wrapped as frozen ``nn.Embedding`` modules.
        use_norm: forwarded to ``RHGNLayer``.
    """

    def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads, cid1_feature, cid2_feature,
                 cid3_feature, cid4_feature, use_norm=True):
        super().__init__()
        self.node_dict = node_dict
        self.edge_dict = edge_dict
        self.gcs = nn.ModuleList()
        self.n_inp = n_inp
        self.n_hid = n_hid
        self.n_out = n_out
        self.n_layers = n_layers
        # NOTE: submodules are created in a fixed order so seeded weight
        # initialisation and parameter order (optimizer state) stay stable.
        # One input projection per node type, indexed by node_dict[type].
        self.adapt_ws = nn.ModuleList()
        for _ in range(len(node_dict)):
            self.adapt_ws.append(nn.Linear(n_inp, n_hid))
        for _ in range(n_layers):
            self.gcs.append(RHGNLayer(n_hid, n_hid, node_dict, edge_dict, n_heads, use_norm=use_norm))
        self.out = nn.Linear(n_hid, n_out)
        # Frozen category embedding tables.
        self.cid1_feature = nn.Embedding(cid1_feature.size(0), cid1_feature.size(1))
        self.cid1_feature.weight = nn.Parameter(cid1_feature)
        self.cid1_feature.weight.requires_grad = False
        self.cid2_feature = nn.Embedding(cid2_feature.size(0), cid2_feature.size(1))
        self.cid2_feature.weight = nn.Parameter(cid2_feature)
        self.cid2_feature.weight.requires_grad = False
        self.cid3_feature = nn.Embedding(cid3_feature.size(0), cid3_feature.size(1))
        self.cid3_feature.weight = nn.Parameter(cid3_feature)
        self.cid3_feature.weight.requires_grad = False
        self.cid4_feature = nn.Embedding(cid4_feature.size(0), cid4_feature.size(1))
        self.cid4_feature.weight = nn.Parameter(cid4_feature)
        self.cid4_feature.weight.requires_grad = False
        # Not used by forward; kept so state-dict keys are unchanged.
        self.excitation = nn.Sequential(
            nn.Linear(4, 32, bias=False),
            nn.ReLU(),
            nn.Linear(32, 4, bias=False),
            nn.ReLU(),
        )
        # Attention projections: query from item features, keys/values from
        # the 200-wide category embeddings.
        self.query = nn.Linear(200, n_inp)
        self.key = nn.Linear(200, n_inp)
        self.value = nn.Linear(200, n_inp)
        # Gate logit; sigmoid(skip) weights attention output vs. raw features.
        self.skip = nn.Parameter(torch.ones(1))
        # Per-slot projections; not used by forward, kept for state-dict keys.
        self.l1 = nn.Linear(200, n_inp)
        self.l2 = nn.Linear(200, n_inp)
        self.l3 = nn.Linear(200, n_inp)
        self.l4 = nn.Linear(200, n_inp)

    def forward(self, input_nodes, output_nodes, blocks, out_key, label_key, is_train=True, print_flag=False):
        """Compute outputs and labels for the destination nodes of the last block.

        Args:
            input_nodes, output_nodes: unused.
            blocks: per-layer blocks; ``blocks[0]`` must carry ``cid1`` and
                ``inp`` on its ``item`` source nodes and ``inp`` on its
                ``user`` source nodes.
            out_key: node type whose representations are returned.
            label_key: label field on ``blocks[-1]``'s ``out_key``
                destination nodes.
            is_train, print_flag: forwarded to every ``RHGNLayer``.

        Returns:
            ``(outputs, labels)``: ``self.out`` applied to the final
            ``out_key`` representations, and the matching labels.
        """
        item_data = blocks[0].srcnodes['item'].data
        cid1_emb = self.cid1_feature(item_data['cid1'].unsqueeze(1))  # (N, 1, 200)
        # NOTE(review): the cid1 embedding fills all three attention slots;
        # the cid2/cid3/cid4 tables are not consulted. With identical keys
        # the softmax weights are uniform (1/3). Confirm this ablation is
        # intended.
        cat_emb = torch.cat((cid1_emb, cid1_emb, cid1_emb), 1)  # (N, 3, 200)

        item_feature = item_data['inp']  # (N, 200)
        user_feature = blocks[0].srcnodes['user'].data['inp']

        k = self.key(cat_emb)  # (N, 3, n_inp)
        v = self.value(cat_emb)  # (N, 3, n_inp)
        q = self.query(item_feature.unsqueeze(-2))  # (N, 1, n_inp)
        # Scaled dot-product scores of the item query against each slot.
        att_score = torch.einsum("bij,bjk->bik", k, q.transpose(1, 2)) / math.sqrt(self.n_inp)  # (N, 3, 1)
        att_score = torch.softmax(att_score, dim=1)
        alpha = torch.sigmoid(self.skip)  # (1,)
        attended = torch.mean(v * att_score, dim=-2)  # (N, n_inp)
        item_feature = alpha * attended + (1 - alpha) * item_feature

        h = {
            'item': F.gelu(self.adapt_ws[self.node_dict['item']](item_feature)),
            'user': F.gelu(self.adapt_ws[self.node_dict['user']](user_feature)),
        }
        for i in range(self.n_layers):
            h = self.gcs[i](blocks[i], h, is_train=is_train, print_flag=print_flag)
        outputs = self.out(h[out_key])
        labels = blocks[-1].dstnodes[out_key].data[label_key]
        return outputs, labels