|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np, itertools, random, copy, math |
|
|
from transformers import BertModel, BertConfig |
|
|
from transformers import AutoTokenizer, AutoModelWithLMHead |
|
|
from model_utils import * |
|
|
|
|
|
|
|
|
class BertERC(nn.Module):
    """BERT-based utterance classifier: pooled [CLS] representation -> MLP head."""

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args
        self.dropout = nn.Dropout(args.dropout)

        # Load the encoder config and pretrained weights.
        # NOTE(review): assumes args.bert_model_dir ends with a path separator — confirm.
        self.bert_config = BertConfig.from_json_file(args.bert_model_dir + 'config.json')
        self.bert = BertModel.from_pretrained(args.home_dir + args.bert_model_dir, config=self.bert_config)

        # Classification head: (mlp_layers) Linear+ReLU blocks, then the output projection.
        head = [nn.Linear(args.bert_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            head.extend([nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()])
        head.append(nn.Linear(args.hidden_dim, num_class))
        self.out_mlp = nn.Sequential(*head)

    def forward(self, content_ids, token_types, utterance_len, seq_len):
        # Pooled output is the second element of the BERT return value.
        # NOTE(review): token_types / utterance_len / seq_len are accepted but unused here.
        pooled = self.bert(content_ids)[1]
        return self.out_mlp(self.dropout(pooled))
|
|
|
|
|
|
|
|
class DAGERC(nn.Module):
    """DAG-structured ERC model: per layer, a GAT-style gather over preceding
    utterances feeds a GRU update; all layer states plus raw features go to an MLP."""

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args
        self.dropout = nn.Dropout(args.dropout)
        self.gnn_layers = args.gnn_layers

        # Optional relation-aware attention; the 2-entry embedding encodes the
        # binary speaker relation from s_mask.
        self.rel_attn = not args.no_rel_attn
        if self.rel_attn:
            self.rel_emb = nn.Embedding(2, args.hidden_dim)

        # One gather (attention) module per GNN layer.
        # NOTE(review): 'Gatdot'/'Gatdot_rel' casing differs from 'GatDot'/'GatDot_rel'
        # used by the other classes in this file — confirm which names model_utils exports.
        if self.args.attn_type == 'linear':
            gat_cls = GatLinear if args.no_rel_attn else GatLinear_rel
        else:
            gat_cls = Gatdot if args.no_rel_attn else Gatdot_rel
        self.gather = nn.ModuleList(gat_cls(args.hidden_dim) for _ in range(args.gnn_layers))

        # One GRU cell per layer for the node-state update.
        self.grus = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers)
        )

        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)

        # The classifier consumes every layer's states plus the raw features.
        in_dim = args.hidden_dim * (args.gnn_layers + 1) + args.emb_dim
        head = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            head.extend([nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()])
        head.append(nn.Linear(args.hidden_dim, num_class))
        self.out_mlp = nn.Sequential(*head)

    def forward(self, features, adj, s_mask):
        '''
        :param features: (B, N, D)
        :param adj: (B, N, N)
        :param s_mask: (B, N, N)
        :return: per-utterance class logits
        '''
        num_utter = features.size()[1]
        if self.rel_attn:
            rel_ft = self.rel_emb(s_mask)

        layer_in = F.relu(self.fc1(features))
        all_states = [layer_in]
        for layer in range(self.args.gnn_layers):
            gat = self.gather[layer]
            gru = self.grus[layer]
            prev = all_states[layer]
            # Seed with the first utterance; it has no predecessors to attend over.
            states = gru(prev[:, 0, :]).unsqueeze(1)
            for i in range(1, num_utter):
                node = prev[:, i, :]
                # Attend over the i preceding utterance states, then update via GRU.
                if self.rel_attn:
                    _, ctx = gat(node, states, states, adj[:, i, :i], rel_ft[:, i, :i, :])
                else:
                    _, ctx = gat(node, states, states, adj[:, i, :i])
                states = torch.cat((states, gru(node, ctx).unsqueeze(1)), dim=1)
            all_states.append(states)

        all_states.append(features)
        return self.out_mlp(torch.cat(all_states, dim=2))
|
|
|
|
|
|
|
|
|
|
|
class DAGERC_fushion(nn.Module):
    """DAG-ERC with fused node updates: each utterance state is the sum of a
    context-GRU update (C) and a party-GRU update (P), gathered along the DAG."""

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args
        self.dropout = nn.Dropout(args.dropout)
        self.gnn_layers = args.gnn_layers
        self.rel_attn = not args.no_rel_attn

        # One gather (attention) module per GNN layer, chosen by attention type.
        if self.args.attn_type == 'linear':
            gat_cls = GatLinear if args.no_rel_attn else GatLinear_rel
            self.gather = nn.ModuleList(gat_cls(args.hidden_dim) for _ in range(args.gnn_layers))
        elif self.args.attn_type == 'dotprod':
            gat_cls = GatDot if args.no_rel_attn else GatDot_rel
            self.gather = nn.ModuleList(gat_cls(args.hidden_dim) for _ in range(args.gnn_layers))
        elif self.args.attn_type == 'rgcn':
            self.gather = nn.ModuleList(GAT_dialoggcn_v1(args.hidden_dim) for _ in range(args.gnn_layers))

        # Context (C) and party (P) GRU cells, one of each per layer.
        self.grus_c = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))
        self.grus_p = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))

        # NOTE(review): fcs is constructed but not referenced in forward().
        self.fcs = nn.ModuleList(
            nn.Linear(args.hidden_dim * 2, args.hidden_dim) for _ in range(args.gnn_layers))

        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)
        self.nodal_att_type = args.nodal_att_type

        # The classifier consumes every layer's states plus the raw features.
        in_dim = args.hidden_dim * (args.gnn_layers + 1) + args.emb_dim
        head = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            head.extend([nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()])
        head.append(self.dropout)
        head.append(nn.Linear(args.hidden_dim, num_class))
        self.out_mlp = nn.Sequential(*head)

        self.attentive_node_features = attentive_node_features(in_dim)

    def forward(self, features, adj, s_mask, s_mask_onehot, lengths):
        '''
        :param features: (B, N, D)
        :param adj: (B, N, N)
        :param s_mask: (B, N, N)
        :param s_mask_onehot: (B, N, N, 2)
        :return: per-utterance class logits
        '''
        num_utter = features.size()[1]

        layer_in = F.relu(self.fc1(features))
        all_states = [layer_in]
        for layer in range(self.args.gnn_layers):
            gru_c = self.grus_c[layer]
            gru_p = self.grus_p[layer]
            prev = all_states[layer]

            # First utterance: zero aggregated context.
            C = gru_c(prev[:, 0, :]).unsqueeze(1)
            ctx = torch.zeros_like(C).squeeze(1)
            P = gru_p(ctx, prev[:, 0, :]).unsqueeze(1)
            states = C + P

            for i in range(1, num_utter):
                node = prev[:, i, :]
                # rgcn and rel-attention variants both take the speaker mask slice.
                if self.args.attn_type == 'rgcn' or self.rel_attn:
                    _, ctx = self.gather[layer](node, states, states, adj[:, i, :i], s_mask[:, i, :i])
                else:
                    _, ctx = self.gather[layer](node, states, states, adj[:, i, :i])
                # Fused update: sum of the context-GRU and party-GRU outputs.
                C = gru_c(node, ctx).unsqueeze(1)
                P = gru_p(ctx, node).unsqueeze(1)
                states = torch.cat((states, C + P), dim=1)

            all_states.append(states)

        all_states.append(features)
        H = torch.cat(all_states, dim=2)
        H = self.attentive_node_features(H, lengths, self.nodal_att_type)
        return self.out_mlp(H)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DAGERC_new_1(nn.Module):
    """Dual-channel DAG-ERC: a short-distance channel (adj_1) and a long-distance
    channel (adj_2), each with its own gather/GRU stacks, fused per layer via
    bi-affine cross-attention. A DiffLoss term pushes the channels apart.

    Fix vs. previous revision: the dead `are_equal = all(torch.equal(...))`
    computation (result never used, O(B*N^2) wasted per forward) was removed,
    and the duplicated channel scan was factored into `_run_channel`.
    """

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args
        self.dropout = nn.Dropout(args.dropout)
        self.gnn_layers = args.gnn_layers
        self.rel_attn = not args.no_rel_attn

        if self.args.attn_type == 'linear':
            # NOTE(review): forward() only uses gather_short/gather_long, which are
            # created for attn_type == 'rgcn'; 'linear'/'dotprod' appear unsupported here.
            gat_cls = GatLinear if args.no_rel_attn else GatLinear_rel
            self.gather = nn.ModuleList(gat_cls(args.hidden_dim) for _ in range(args.gnn_layers))
        elif self.args.attn_type == 'dotprod':
            gat_cls = GatDot if args.no_rel_attn else GatDot_rel
            self.gather = nn.ModuleList(gat_cls(args.hidden_dim) for _ in range(args.gnn_layers))
        elif self.args.attn_type == 'rgcn':
            # Separate attention stacks for the short- and long-distance channels.
            self.gather_short = nn.ModuleList(GAT_dialoggcn_v1(args.hidden_dim) for _ in range(args.gnn_layers))
            self.gather_long = nn.ModuleList(GAT_dialoggcn_v1(args.hidden_dim) for _ in range(args.gnn_layers))

        # Context (C) and party (P) GRU cells for each channel, one per layer.
        self.grus_c_short = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))
        self.grus_c_long = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))
        self.grus_p_short = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))
        self.grus_p_long = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))

        # NOTE(review): fcs_short/fcs_long are constructed but not used in forward().
        self.fcs_short = nn.ModuleList(
            nn.Linear(args.hidden_dim * 2, args.hidden_dim) for _ in range(args.gnn_layers))
        self.fcs_long = nn.ModuleList(
            nn.Linear(args.hidden_dim * 2, args.hidden_dim) for _ in range(args.gnn_layers))

        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)
        self.nodal_att_type = args.nodal_att_type

        # Final features per node: last-layer short state + long state + raw embedding.
        in_dim = (args.hidden_dim * 2) + args.emb_dim
        head = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            head.extend([nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()])
        head.append(self.dropout)
        head.append(nn.Linear(args.hidden_dim, num_class))
        self.out_mlp = nn.Sequential(*head)

        self.attentive_node_features = attentive_node_features(in_dim)

        # Bi-affine weights for cross-channel attention.
        self.affine1 = nn.Parameter(torch.empty(size=(args.hidden_dim, args.hidden_dim)))
        nn.init.xavier_uniform_(self.affine1.data, gain=1.414)
        self.affine2 = nn.Parameter(torch.empty(size=(args.hidden_dim, args.hidden_dim)))
        nn.init.xavier_uniform_(self.affine2.data, gain=1.414)

        self.diff_loss = DiffLoss(args)
        self.beta = args.diffloss  # weight of the channel-difference loss

    def _run_channel(self, H0, adj, s_mask, gathers, grus_c, grus_p):
        """Scan one DAG channel; return a list of per-layer state tensors (B, N, hidden)."""
        num_utter = H0.size(1)
        H = [H0]
        for l in range(self.args.gnn_layers):
            prev = H[l]
            # First utterance: zero aggregated context.
            C = grus_c[l](prev[:, 0, :]).unsqueeze(1)
            ctx = torch.zeros_like(C).squeeze(1)
            P = grus_p[l](ctx, prev[:, 0, :]).unsqueeze(1)
            states = C + P
            for i in range(1, num_utter):
                node = prev[:, i, :]
                if self.args.attn_type == 'rgcn' or self.rel_attn:
                    _, ctx = gathers[l](node, states, states, adj[:, i, :i], s_mask[:, i, :i])
                else:
                    _, ctx = gathers[l](node, states, states, adj[:, i, :i])
                C = grus_c[l](node, ctx).unsqueeze(1)
                P = grus_p[l](ctx, node).unsqueeze(1)
                states = torch.cat((states, C + P), dim=1)
            H.append(states)
        return H[1:]

    def forward(self, features, adj_1, adj_2, s_mask, s_mask_onehot, lengths):
        '''
        :param features: (B, N, D)
        :param adj_1: (B, N, N) short-distance adjacency
        :param adj_2: (B, N, N) long-distance adjacency
        :param s_mask: (B, N, N)
        :param s_mask_onehot: (B, N, N, 2) — unused here
        :return: (logits, weighted diff loss)
        '''
        H0 = F.relu(self.fc1(features))

        short_states = self._run_channel(H0, adj_1, s_mask,
                                         self.gather_short, self.grus_c_short, self.grus_p_short)
        long_states = self._run_channel(H0, adj_2, s_mask,
                                        self.gather_long, self.grus_c_long, self.grus_p_long)

        # Per-layer bi-affine cross-attention between the channels; only the
        # last layer's fused states feed the classifier.
        diff_loss = 0
        h_short = h_long = None
        for l in range(self.args.gnn_layers):
            h_s = short_states[l]
            h_l = long_states[l]

            # Encourage the two channels to encode different information.
            diff_loss = self.diff_loss(h_s, h_l) + diff_loss

            A1 = F.softmax(torch.bmm(torch.matmul(h_s, self.affine1), torch.transpose(h_l, 1, 2)), dim=2)
            A2 = F.softmax(torch.bmm(torch.matmul(h_l, self.affine2), torch.transpose(h_s, 1, 2)), dim=2)
            h_short = torch.bmm(A1, h_l)
            h_long = torch.bmm(A2, h_s)

            # Dropout between layers, but not after the last one.
            if l < self.args.gnn_layers - 1:
                h_short = self.dropout(h_short)
                h_long = self.dropout(h_long)

        H_final = torch.cat([h_short, h_long, features], dim=2)
        H_final = self.attentive_node_features(H_final, lengths, self.nodal_att_type)
        logits = self.out_mlp(H_final)
        return logits, self.beta * diff_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DAGERC_new_2(nn.Module):
    """Bidirectional dual-channel DAG-ERC: each channel (short = adj_1,
    long = adj_2) is scanned in the original and in the flipped utterance
    order, the two directions are concatenated per layer, and the channels
    are fused via bi-affine cross-attention with a DiffLoss term.

    Fix vs. previous revision: removed leftover debug `print` statements that
    dumped entire feature tensors on every forward pass, and the dead
    `are_equal = all(torch.equal(...))` computation whose result was never used.
    The duplicated channel scan was factored into `_run_channel`.
    """

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args
        self.dropout = nn.Dropout(args.dropout)
        self.gnn_layers = args.gnn_layers
        self.rel_attn = not args.no_rel_attn

        if self.args.attn_type == 'linear':
            # NOTE(review): forward() only uses gather_short/gather_long, which are
            # created for attn_type == 'rgcn'; 'linear'/'dotprod' appear unsupported here.
            gat_cls = GatLinear if args.no_rel_attn else GatLinear_rel
            self.gather = nn.ModuleList(gat_cls(args.hidden_dim) for _ in range(args.gnn_layers))
        elif self.args.attn_type == 'dotprod':
            gat_cls = GatDot if args.no_rel_attn else GatDot_rel
            self.gather = nn.ModuleList(gat_cls(args.hidden_dim) for _ in range(args.gnn_layers))
        elif self.args.attn_type == 'rgcn':
            self.gather_short = nn.ModuleList(GAT_dialoggcn_v1(args.hidden_dim) for _ in range(args.gnn_layers))
            self.gather_long = nn.ModuleList(GAT_dialoggcn_v1(args.hidden_dim) for _ in range(args.gnn_layers))

        # Context (C) and party (P) GRU cells for each channel, one per layer;
        # each channel shares its cells between the forward and reversed scans.
        self.grus_c_short = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))
        self.grus_c_long = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))
        self.grus_p_short = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))
        self.grus_p_long = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))

        # NOTE(review): fcs_short/fcs_long are constructed but not used in forward().
        self.fcs_short = nn.ModuleList(
            nn.Linear(args.hidden_dim * 2, args.hidden_dim) for _ in range(args.gnn_layers))
        self.fcs_long = nn.ModuleList(
            nn.Linear(args.hidden_dim * 2, args.hidden_dim) for _ in range(args.gnn_layers))

        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)
        self.nodal_att_type = args.nodal_att_type

        # Final features: bidirectional short (2*hidden) + bidirectional long
        # (2*hidden) + raw embedding.
        in_dim = (args.hidden_dim * 2) * 2 + args.emb_dim
        head = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            head.extend([nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()])
        head.append(self.dropout)
        head.append(nn.Linear(args.hidden_dim, num_class))
        self.out_mlp = nn.Sequential(*head)

        self.attentive_node_features = attentive_node_features(in_dim)

        # Bi-affine weights for cross-channel attention over the 2*hidden states.
        self.affine1 = nn.Parameter(torch.empty(size=(args.hidden_dim * 2, args.hidden_dim * 2)))
        nn.init.xavier_uniform_(self.affine1.data, gain=1.414)
        self.affine2 = nn.Parameter(torch.empty(size=(args.hidden_dim * 2, args.hidden_dim * 2)))
        nn.init.xavier_uniform_(self.affine2.data, gain=1.414)

        self.diff_loss = DiffLoss(args)
        self.beta = args.diffloss  # weight of the channel-difference loss

    def _run_channel(self, H0, adj, s_mask, gathers, grus_c, grus_p):
        """Scan one DAG direction; return a list of per-layer state tensors (B, N, hidden)."""
        num_utter = H0.size(1)
        H = [H0]
        for l in range(self.args.gnn_layers):
            prev = H[l]
            # First utterance: zero aggregated context.
            C = grus_c[l](prev[:, 0, :]).unsqueeze(1)
            ctx = torch.zeros_like(C).squeeze(1)
            P = grus_p[l](ctx, prev[:, 0, :]).unsqueeze(1)
            states = C + P
            for i in range(1, num_utter):
                node = prev[:, i, :]
                if self.args.attn_type == 'rgcn' or self.rel_attn:
                    _, ctx = gathers[l](node, states, states, adj[:, i, :i], s_mask[:, i, :i])
                else:
                    _, ctx = gathers[l](node, states, states, adj[:, i, :i])
                C = grus_c[l](node, ctx).unsqueeze(1)
                P = grus_p[l](ctx, node).unsqueeze(1)
                states = torch.cat((states, C + P), dim=1)
            H.append(states)
        return H[1:]

    def forward(self, features, adj_1, adj_2, s_mask, s_mask_onehot, lengths):
        '''
        :param features: (B, N, D)
        :param adj_1: (B, N, N) short-distance adjacency
        :param adj_2: (B, N, N) long-distance adjacency
        :param s_mask: (B, N, N)
        :param s_mask_onehot: (B, N, N, 2) — unused here
        :return: (logits, weighted diff loss)
        '''
        H0 = F.relu(self.fc1(features))

        # Reversed-direction inputs: flip along the utterance axis.
        features_rev = torch.flip(features, dims=[1])
        H0_rev = F.relu(self.fc1(features_rev))
        s_mask_rev = torch.flip(s_mask, dims=[1, 2])

        # Short-distance channel, both directions, concatenated per layer.
        # NOTE(review): the reversed states are concatenated without flipping back,
        # so position i pairs utterance i with utterance N-1-i — confirm intended.
        fwd_short = self._run_channel(H0, adj_1, s_mask,
                                      self.gather_short, self.grus_c_short, self.grus_p_short)
        rev_short = self._run_channel(H0_rev, torch.flip(adj_1, dims=[1, 2]), s_mask_rev,
                                      self.gather_short, self.grus_c_short, self.grus_p_short)
        short_states = [torch.cat(pair, dim=2) for pair in zip(fwd_short, rev_short)]

        # Long-distance channel, both directions, concatenated per layer.
        fwd_long = self._run_channel(H0, adj_2, s_mask,
                                     self.gather_long, self.grus_c_long, self.grus_p_long)
        rev_long = self._run_channel(H0_rev, torch.flip(adj_2, dims=[1, 2]), s_mask_rev,
                                     self.gather_long, self.grus_c_long, self.grus_p_long)
        long_states = [torch.cat(pair, dim=2) for pair in zip(fwd_long, rev_long)]

        # Per-layer bi-affine cross-attention between the channels; only the
        # last layer's fused states feed the classifier.
        diff_loss = 0
        h_short = h_long = None
        for l in range(self.args.gnn_layers):
            h_s = short_states[l]
            h_l = long_states[l]

            # Encourage the two channels to encode different information.
            diff_loss = self.diff_loss(h_s, h_l) + diff_loss

            A1 = F.softmax(torch.bmm(torch.matmul(h_s, self.affine1), torch.transpose(h_l, 1, 2)), dim=2)
            A2 = F.softmax(torch.bmm(torch.matmul(h_l, self.affine2), torch.transpose(h_s, 1, 2)), dim=2)
            h_short = torch.bmm(A1, h_l)
            h_long = torch.bmm(A2, h_s)

            # Dropout between layers, but not after the last one.
            if l < self.args.gnn_layers - 1:
                h_short = self.dropout(h_short)
                h_long = self.dropout(h_long)

        H_final = torch.cat([h_short, h_long, features], dim=2)
        H_final = self.attentive_node_features(H_final, lengths, self.nodal_att_type)
        logits = self.out_mlp(H_final)
        return logits, self.beta * diff_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DAGERC_new_3(nn.Module):
    """Dual-channel DAG-ERC with additive fusion: the short (adj_1) and long
    (adj_2) channels are scanned independently, their stacked layer outputs are
    summed element-wise, concatenated with the raw features, and classified.

    Fix vs. previous revision: removed the dead `are_equal = all(torch.equal(...))`
    computation (result never used) and the unused H_combined_*_list locals;
    the duplicated channel scan was factored into `_run_channel`.
    """

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args
        self.dropout = nn.Dropout(args.dropout)
        self.gnn_layers = args.gnn_layers
        self.rel_attn = not args.no_rel_attn

        if self.args.attn_type == 'linear':
            # NOTE(review): forward() only uses gather_short/gather_long, which are
            # created for attn_type == 'rgcn'; 'linear'/'dotprod' appear unsupported here.
            gat_cls = GatLinear if args.no_rel_attn else GatLinear_rel
            self.gather = nn.ModuleList(gat_cls(args.hidden_dim) for _ in range(args.gnn_layers))
        elif self.args.attn_type == 'dotprod':
            gat_cls = GatDot if args.no_rel_attn else GatDot_rel
            self.gather = nn.ModuleList(gat_cls(args.hidden_dim) for _ in range(args.gnn_layers))
        elif self.args.attn_type == 'rgcn':
            self.gather_short = nn.ModuleList(GAT_dialoggcn_v1(args.hidden_dim) for _ in range(args.gnn_layers))
            self.gather_long = nn.ModuleList(GAT_dialoggcn_v1(args.hidden_dim) for _ in range(args.gnn_layers))

        # Context (C) and party (P) GRU cells for each channel, one per layer.
        self.grus_c_short = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))
        self.grus_c_long = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))
        self.grus_p_short = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))
        self.grus_p_long = nn.ModuleList(
            nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers))

        # NOTE(review): fcs_short/fcs_long are constructed but not used in forward().
        self.fcs_short = nn.ModuleList(
            nn.Linear(args.hidden_dim * 2, args.hidden_dim) for _ in range(args.gnn_layers))
        self.fcs_long = nn.ModuleList(
            nn.Linear(args.hidden_dim * 2, args.hidden_dim) for _ in range(args.gnn_layers))

        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)
        self.nodal_att_type = args.nodal_att_type

        # Final features: summed per-layer channel states (hidden * (L+1)) + raw embedding.
        in_dim = (args.hidden_dim * (args.gnn_layers + 1)) + args.emb_dim
        head = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            head.extend([nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()])
        head.append(self.dropout)
        head.append(nn.Linear(args.hidden_dim, num_class))
        self.out_mlp = nn.Sequential(*head)

        self.attentive_node_features = attentive_node_features(in_dim)

    def _run_channel(self, H0, adj, s_mask, gathers, grus_c, grus_p):
        """Scan one DAG channel; return a list of per-layer state tensors (B, N, hidden)."""
        num_utter = H0.size(1)
        H = [H0]
        for l in range(self.args.gnn_layers):
            prev = H[l]
            # First utterance: zero aggregated context.
            C = grus_c[l](prev[:, 0, :]).unsqueeze(1)
            ctx = torch.zeros_like(C).squeeze(1)
            P = grus_p[l](ctx, prev[:, 0, :]).unsqueeze(1)
            states = C + P
            for i in range(1, num_utter):
                node = prev[:, i, :]
                if self.args.attn_type == 'rgcn' or self.rel_attn:
                    _, ctx = gathers[l](node, states, states, adj[:, i, :i], s_mask[:, i, :i])
                else:
                    _, ctx = gathers[l](node, states, states, adj[:, i, :i])
                C = grus_c[l](node, ctx).unsqueeze(1)
                P = grus_p[l](ctx, node).unsqueeze(1)
                states = torch.cat((states, C + P), dim=1)
            H.append(states)
        return H[1:]

    def forward(self, features, adj_1, adj_2, s_mask, s_mask_onehot, lengths):
        '''
        :param features: (B, N, D)
        :param adj_1: (B, N, N) short-distance adjacency
        :param adj_2: (B, N, N) long-distance adjacency
        :param s_mask: (B, N, N)
        :param s_mask_onehot: (B, N, N, 2) — unused here
        :return: per-utterance class logits
        '''
        H0 = F.relu(self.fc1(features))

        short_states = self._run_channel(H0, adj_1, s_mask,
                                         self.gather_short, self.grus_c_short, self.grus_p_short)
        long_states = self._run_channel(H0, adj_2, s_mask,
                                        self.gather_long, self.grus_c_long, self.grus_p_long)

        # Stack [H0, layer1, ..., layerL] per channel and fuse by element-wise sum.
        H_short = torch.cat([H0] + short_states, dim=2)
        H_long = torch.cat([H0] + long_states, dim=2)
        sum_features = H_short + H_long

        H_final = torch.cat((sum_features, features), dim=2)
        H_final = self.attentive_node_features(H_final, lengths, self.nodal_att_type)
        return self.out_mlp(H_final)
|
|
|
|
|
|
|
|
|
|
|
class DAGERC_new_4(nn.Module):
    """Dual-channel DAG-GNN emotion-recognition model.

    Runs two independent DAG propagation channels over the same utterance
    features — a "short-distance" channel driven by ``adj_1`` and a
    "long-distance" channel driven by ``adj_2`` — then fuses them per GNN
    layer with bilinear cross-channel attention, while a ``DiffLoss`` term
    pushes the two channels to encode complementary information.
    """

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args

        self.dropout = nn.Dropout(args.dropout)
        self.gnn_layers = args.gnn_layers

        # True when graph attention is conditioned on speaker relations.
        self.rel_attn = not args.no_rel_attn

        # One graph-attention module per GNN layer, per channel.
        # BUGFIX: the original only populated `gather_short`/`gather_long`
        # in the 'rgcn' branch, yet forward() uses them for every attn_type,
        # so 'linear'/'dotprod' crashed with AttributeError.  All supported
        # attn types now build both channels.
        def _make_gat():
            if self.args.attn_type == 'linear':
                return GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)
            if self.args.attn_type == 'dotprod':
                return GatDot(args.hidden_dim) if args.no_rel_attn else GatDot_rel(args.hidden_dim)
            if self.args.attn_type == 'rgcn':
                return GAT_dialoggcn_v1(args.hidden_dim)
            raise ValueError('unknown attn_type: %r' % (self.args.attn_type,))

        self.gather_short = nn.ModuleList([_make_gat() for _ in range(args.gnn_layers)])
        self.gather_long = nn.ModuleList([_make_gat() for _ in range(args.gnn_layers)])

        # GRU cells per layer: `grus_c_*` update the current node from the
        # aggregated message; `grus_p_*` update the message from the node.
        self.grus_c_short = nn.ModuleList(
            [nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers)])
        self.grus_c_long = nn.ModuleList(
            [nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers)])
        self.grus_p_short = nn.ModuleList(
            [nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers)])
        self.grus_p_long = nn.ModuleList(
            [nn.GRUCell(args.hidden_dim, args.hidden_dim) for _ in range(args.gnn_layers)])

        # NOTE(review): fcs_short/fcs_long are never referenced in forward();
        # they look like dead parameters but are kept so that existing
        # checkpoints still load — confirm before removing.
        self.fcs_short = nn.ModuleList(
            [nn.Linear(args.hidden_dim * 2, args.hidden_dim) for _ in range(args.gnn_layers)])
        self.fcs_long = nn.ModuleList(
            [nn.Linear(args.hidden_dim * 2, args.hidden_dim) for _ in range(args.gnn_layers)])

        # Projects raw utterance embeddings into the GNN hidden space.
        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)

        self.nodal_att_type = args.nodal_att_type

        # Per utterance the classifier sees one [short ; long] pair
        # (2 * hidden_dim) for each of (gnn_layers + 1) levels, plus the raw
        # input features.
        in_dim = args.hidden_dim * 2 * (args.gnn_layers + 1) + args.emb_dim

        # Output MLP: (mlp_layers) hidden layers with ReLU, dropout before the
        # final classification layer.
        layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
        layers += [self.dropout, nn.Linear(args.hidden_dim, num_class)]
        self.out_mlp = nn.Sequential(*layers)

        self.attentive_node_features = attentive_node_features(in_dim)

        # Bilinear maps for the cross-channel attention in forward().
        self.affine1 = nn.Parameter(torch.empty(size=(args.hidden_dim, args.hidden_dim)))
        nn.init.xavier_uniform_(self.affine1.data, gain=1.414)
        self.affine2 = nn.Parameter(torch.empty(size=(args.hidden_dim, args.hidden_dim)))
        nn.init.xavier_uniform_(self.affine2.data, gain=1.414)

        # Soft-orthogonality loss between the two channels, weighted by beta.
        self.diff_loss = DiffLoss(args)
        self.beta = args.diffloss

    def _propagate(self, H0, num_utter, adj, s_mask, gats, grus_c, grus_p):
        """Run the DAG recurrence for one channel.

        Utterances are processed left-to-right; node ``i`` attends only over
        the already-updated predecessors ``0..i-1`` (rows ``adj[:, i, :i]``).

        Returns a list with one (batch, num_utter, hidden_dim) tensor per
        GNN layer — the short/long per-layer features used by the fusion step.
        """
        H = [H0]
        per_layer = []
        for l in range(self.args.gnn_layers):
            # Utterance 0 has no predecessors: its aggregated message is zero.
            C = grus_c[l](H[l][:, 0, :]).unsqueeze(1)
            M = torch.zeros_like(C).squeeze(1)
            P = grus_p[l](M, H[l][:, 0, :]).unsqueeze(1)
            H1 = C + P
            for i in range(1, num_utter):
                # Gather a message M for node i from its predecessors in H1.
                if self.args.attn_type == 'rgcn':
                    _, M = gats[l](H[l][:, i, :], H1, H1, adj[:, i, :i], s_mask[:, i, :i])
                elif not self.rel_attn:
                    _, M = gats[l](H[l][:, i, :], H1, H1, adj[:, i, :i])
                else:
                    _, M = gats[l](H[l][:, i, :], H1, H1, adj[:, i, :i], s_mask[:, i, :i])
                # Node update: GRU over (state, message) plus (message, state).
                C = grus_c[l](H[l][:, i, :], M).unsqueeze(1)
                P = grus_p[l](M, H[l][:, i, :]).unsqueeze(1)
                H1 = torch.cat((H1, C + P), dim=1)
            H.append(H1)
            per_layer.append(H1)
        return per_layer

    def forward(self, features, adj_1, adj_2, s_mask, s_mask_onehot, lengths):
        """Classify every utterance in a batch of dialogues.

        :param features: (batch, num_utter, emb_dim) utterance embeddings
        :param adj_1: short-distance DAG adjacency, (batch, num_utter, num_utter)
        :param adj_2: long-distance DAG adjacency, same shape as adj_1
        :param s_mask: speaker-relation mask used by relation-aware attention
        :param s_mask_onehot: accepted for interface compatibility; unused here
        :param lengths: per-dialogue utterance counts for nodal attention
        :return: (logits, beta * diff_loss)
        """
        # NOTE: the original also computed `all(torch.equal(a, b) for ...)`
        # over adj_1/adj_2 and discarded the result — dead work, removed.
        num_utter = features.size(1)

        # Shared input projection for both channels.
        H0 = F.relu(self.fc1(features))

        # Two independent propagation channels over different graphs.
        short_layers = self._propagate(H0, num_utter, adj_1, s_mask,
                                       self.gather_short, self.grus_c_short,
                                       self.grus_p_short)
        long_layers = self._propagate(H0, num_utter, adj_2, s_mask,
                                      self.gather_long, self.grus_c_long,
                                      self.grus_p_long)

        # Layer-0 contribution: the shared projection, duplicated so every
        # level contributes a [short ; long] pair of width 2 * hidden_dim.
        H_final = [torch.cat([H0, H0], dim=2)]

        diff_loss = 0
        for l in range(self.args.gnn_layers):
            H_short = short_layers[l]
            H_long = long_layers[l]

            # Encourage the two channels to encode different information.
            diff_loss = self.diff_loss(H_short, H_long) + diff_loss

            # Bilinear cross-channel attention: each channel re-reads the
            # other through a learned affine map.
            A1 = F.softmax(torch.bmm(torch.matmul(H_short, self.affine1),
                                     torch.transpose(H_long, 1, 2)), dim=2)
            A2 = F.softmax(torch.bmm(torch.matmul(H_long, self.affine2),
                                     torch.transpose(H_short, 1, 2)), dim=2)
            H_short_new = torch.bmm(A1, H_long)
            H_long_new = torch.bmm(A2, H_short)

            # Dropout between fusion layers, but not after the last one.
            if l < self.args.gnn_layers - 1:
                H_short_new = self.dropout(H_short_new)
                H_long_new = self.dropout(H_long_new)

            H_final.append(torch.cat([H_short_new, H_long_new], dim=2))

        # Concatenate all levels plus the raw features along the feature axis.
        H_final = torch.cat(H_final, dim=2)
        H_final = torch.cat([H_final, features], dim=2)

        H_final = self.attentive_node_features(H_final, lengths, self.nodal_att_type)
        logits = self.out_mlp(H_final)

        return logits, self.beta * diff_loss