Upload 33 files
Browse files- .gitattributes +4 -35
- .gitignore +2 -0
- README.md +26 -3
- cl.py +77 -0
- data/IEMOCAP/dev_data_roberta.json.feature +3 -0
- data/IEMOCAP/dev_data_roberta_mm.json.feature +3 -0
- data/IEMOCAP/label_vocab.pkl +0 -0
- data/IEMOCAP/speaker_vocab.pkl +0 -0
- data/IEMOCAP/test_data_roberta.json.feature +3 -0
- data/IEMOCAP/test_data_roberta_mm.json.feature +3 -0
- data/IEMOCAP/train_data_roberta.json.feature +3 -0
- data/IEMOCAP/train_data_roberta_mm.json.feature +3 -0
- data/MELD/dev_data_roberta.json.feature +3 -0
- data/MELD/dev_data_roberta_mm.json.feature +3 -0
- data/MELD/label_vocab.pkl +0 -0
- data/MELD/speaker_vocab.pkl +0 -0
- data/MELD/test_data_roberta.json.feature +3 -0
- data/MELD/test_data_roberta_mm.json.feature +3 -0
- data/MELD/train_data_roberta.json.feature +3 -0
- data/MELD/train_data_roberta_mm.json.feature +3 -0
- dataloader.py +81 -0
- dataset.py +230 -0
- evaluate.py +199 -0
- model.py +1199 -0
- model_utils.py +507 -0
- requirements.txt +7 -0
- run.py +239 -0
- saved_models/IEMOCAP/README.txt +2 -0
- saved_models/MELD/README.txt +2 -0
- saved_models/README.txt +2 -0
- similarity_matrix.py +101 -0
- trainer.py +171 -0
- utils.py +16 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,4 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
*.
|
| 4 |
-
*.
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
# Auto detect text files and perform LF normalization
|
| 2 |
+
* text=auto
|
| 3 |
+
*.feature filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.json.feature filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.feature
|
| 2 |
+
*.pkl
|
README.md
CHANGED
|
@@ -1,3 +1,26 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Long-Short Distance Graph Neural Networks and Improved Curriculum Learning for Emotion Recognition in Conversation (Accepted by ECAI2025)
|
| 2 |
+
|
| 3 |
+
Emotion Recognition in Conversation (ERC) is a practical and challenging task. This paper proposes a novel multimodal approach, the Long-Short Distance Graph Neural Network (LSDGNN). Based on the Directed Acyclic Graph (DAG), it constructs a long-distance graph neural network and a short-distance graph neural network to obtain multimodal features of distant and nearby utterances, respectively. To ensure that long- and short-distance features are as distinct as possible in representation while enabling mutual influence between the two modules, we employ a Differential Regularizer and incorporate a BiAffine Module to facilitate feature interaction. In addition, we propose an Improved Curriculum Learning (ICL) to address the challenge of data imbalance. By computing the similarity between different emotions to emphasize the shifts in similar emotions, we design a "weighted emotional shift" metric and develop a difficulty measurer, enabling a training process that prioritizes learning easy samples before harder ones. Experimental results on the IEMOCAP and MELD datasets demonstrate that our model outperforms existing benchmarks.
|
| 4 |
+
|
| 5 |
+
## Requirements
|
| 6 |
+
Python 3.11
|
| 7 |
+
CUDA 12.2
|
| 8 |
+
|
| 9 |
+
After configuring the Python environment and CUDA, you can use `pip install -r requirements.txt` to install the following libraries.
|
| 10 |
+
|
| 11 |
+
torch==2.0.0+cu117
|
| 12 |
+
transformers==4.46.3
|
| 13 |
+
numpy==1.24.2
|
| 14 |
+
pandas==2.1.4
|
| 15 |
+
matplotlib==3.7.1
|
| 16 |
+
scikit-learn==1.2.2
|
| 17 |
+
tqdm==4.67.1
|
| 18 |
+
|
| 19 |
+
### Training
|
| 20 |
+
GPU NVIDIA GeForce RTX 3090
|
| 21 |
+
|
| 22 |
+
for IEMOCAP:
|
| 23 |
+
`python run.py --dataset_name IEMOCAP --gnn_layers 4 --lr 0.0005 --batch_size 16 --epochs 30 --dropout 0.4 --emb_dim 2948 --windowpl 5 --diffloss 0.1 --curriculum --bucket_number 5`
|
| 24 |
+
|
| 25 |
+
for MELD:
|
| 26 |
+
`python run.py --dataset_name MELD --gnn_layers 2 --lr 0.00001 --batch_size 64 --epochs 30 --dropout 0.1 --emb_dim 1666 --windowpl 5 --diffloss 0.2 --curriculum --bucket_number 12`
|
cl.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import similarity_matrix
|
| 3 |
+
|
| 4 |
+
class Dialog:
|
| 5 |
+
def __init__(self, utterances, labels, speakers, features, dataset):
|
| 6 |
+
self.utterances = utterances
|
| 7 |
+
self.labels = labels
|
| 8 |
+
self.speakers = speakers
|
| 9 |
+
self.features = features
|
| 10 |
+
self.dataset = dataset
|
| 11 |
+
self.numberofemotionshifts = 0
|
| 12 |
+
self.numberofspeakers = 0
|
| 13 |
+
self.numberofutterances = 0
|
| 14 |
+
self.difficulty = 0
|
| 15 |
+
self.emotion_variance = 0 # 情感变化度量
|
| 16 |
+
self.emotion_shift_weighted = 0 #加权后的情感变化
|
| 17 |
+
self.cc()
|
| 18 |
+
|
| 19 |
+
def __getitem__(self, item):
|
| 20 |
+
if item == 'utterances':
|
| 21 |
+
return self.utterances
|
| 22 |
+
elif item == 'labels':
|
| 23 |
+
return self.labels
|
| 24 |
+
elif item == 'speakers':
|
| 25 |
+
return self.speakers
|
| 26 |
+
elif item == 'features':
|
| 27 |
+
return self.features
|
| 28 |
+
#measure the difficulty of a dialog
|
| 29 |
+
def cc(self):
|
| 30 |
+
# 情感数字到文字的映射字典
|
| 31 |
+
# print(self.dataset)
|
| 32 |
+
if self.dataset == 'MELD':
|
| 33 |
+
emotion_map = { -1: 'null', 0: 'neutral', 1: 'surprise', 2: 'fear', 3: 'sadness', 4: 'joy', 5: 'disgust', 6: 'anger'}
|
| 34 |
+
else:
|
| 35 |
+
emotion_map = { -1: 'null', 0:'excitement', 1: 'neutral', 2:'frustration', 3:'sadness', 4:'happiness', 5:'anger'}
|
| 36 |
+
# print(emotion_map)
|
| 37 |
+
self.numberofutterances = len(self.utterances)
|
| 38 |
+
speaker_emo = {}
|
| 39 |
+
for i in range(0, len(self.labels)):
|
| 40 |
+
if (self.speakers[i] in speaker_emo):
|
| 41 |
+
speaker_emo[self.speakers[i]].append(emotion_map[self.labels[i]])
|
| 42 |
+
else:
|
| 43 |
+
speaker_emo[self.speakers[i]] = [emotion_map[self.labels[i]]]
|
| 44 |
+
|
| 45 |
+
# 获取情感相似度矩阵
|
| 46 |
+
matrix, emotion_to_index = similarity_matrix.get_similarity_matrix(self.dataset)
|
| 47 |
+
# print(matrix)
|
| 48 |
+
k = 1
|
| 49 |
+
b = 0.4
|
| 50 |
+
for key in speaker_emo:
|
| 51 |
+
# prev_emo = None
|
| 52 |
+
for i in range(0, len(speaker_emo[key]) - 1):
|
| 53 |
+
current_emo = speaker_emo[key][i]
|
| 54 |
+
next_emo = speaker_emo[key][i + 1]
|
| 55 |
+
if current_emo != next_emo and current_emo != 'null' and next_emo != 'null':
|
| 56 |
+
self.numberofemotionshifts += 1
|
| 57 |
+
current_emo_index = emotion_to_index[current_emo]
|
| 58 |
+
next_emo_index = emotion_to_index[next_emo]
|
| 59 |
+
#线性缩放
|
| 60 |
+
#当k为正数时,similarity_score越小说明差距越大,越大说明差距越小,侧重于差距小的情感
|
| 61 |
+
#当k为负数时,反之,侧重于差距大的情感
|
| 62 |
+
similarity_score = abs(matrix[current_emo_index][next_emo_index]) * k + b
|
| 63 |
+
self.emotion_shift_weighted += similarity_score
|
| 64 |
+
|
| 65 |
+
#print(speaker_emo[key])
|
| 66 |
+
'''
|
| 67 |
+
for key in speaker_emo:
|
| 68 |
+
# Convert labels to indices
|
| 69 |
+
emotions = speaker_emo[key]
|
| 70 |
+
self.emotion_variance += np.std(emotions) # 计算每个发言人的情感方差
|
| 71 |
+
'''
|
| 72 |
+
# print(self.numberofemotionshifts)
|
| 73 |
+
# print(self.emotion_shift_weighted)
|
| 74 |
+
# print('---------')
|
| 75 |
+
self.numberofspeakers = len(set(self.speakers))
|
| 76 |
+
self.difficulty = (self.emotion_shift_weighted + self.numberofspeakers ) / (self.numberofutterances + self.numberofspeakers)
|
| 77 |
+
# self.difficulty = (self.numberofemotionshifts + self.numberofspeakers ) / (self.numberofutterances + self.numberofspeakers)
|
data/IEMOCAP/dev_data_roberta.json.feature
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c8c1835251e4c85af23e65d1a312990b411d36574f9bd1063f6ca3d20d8f2eda
|
| 3 |
+
size 32689394
|
data/IEMOCAP/dev_data_roberta_mm.json.feature
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6bd2b4d3fd2a092c2f7b390a0c5115533bf4fb5b0724f49e76a84d98bf054bdb
|
| 3 |
+
size 62364912
|
data/IEMOCAP/label_vocab.pkl
ADDED
|
Binary file (98 Bytes). View file
|
|
|
data/IEMOCAP/speaker_vocab.pkl
ADDED
|
Binary file (7.61 kB). View file
|
|
|
data/IEMOCAP/test_data_roberta.json.feature
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f7c81d1690f997cd02672b3dd1f077744c997c0446b9ce6599a54f37db4950f
|
| 3 |
+
size 51258384
|
data/IEMOCAP/test_data_roberta_mm.json.feature
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5bcf5c03d714bbf27105e72141a6e78f1d58278d794e46be82cee53beddcfab1
|
| 3 |
+
size 103033265
|
data/IEMOCAP/train_data_roberta.json.feature
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6299f2656d0d4e7a7c2e1f91c7e0f811088d500c6320b741ce742a72d6993c99
|
| 3 |
+
size 151409987
|
data/IEMOCAP/train_data_roberta_mm.json.feature
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cdc380726af732912b539b335233053a465237d29c7369477851bf8dc62e5f28
|
| 3 |
+
size 307542106
|
data/MELD/dev_data_roberta.json.feature
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2fe9a5f3b0464ff597061a833fa62cfe2337311b53699d1878bba1827106a0a8
|
| 3 |
+
size 29797677
|
data/MELD/dev_data_roberta_mm.json.feature
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:01930f39300566f37d433a60c323eedc2a8cfd53486d5f5cbda56371e85f8aff
|
| 3 |
+
size 42149762
|
data/MELD/label_vocab.pkl
ADDED
|
Binary file (128 Bytes). View file
|
|
|
data/MELD/speaker_vocab.pkl
ADDED
|
Binary file (5.08 kB). View file
|
|
|
data/MELD/test_data_roberta.json.feature
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0a60fe91031fe79749fa813336309ae5a1804561ff70dd5a6320029a4fdee0f0
|
| 3 |
+
size 70132507
|
data/MELD/test_data_roberta_mm.json.feature
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0cead63677c699daef4d382047e5592066922c6659d5bca9bc5ce57367dd65a7
|
| 3 |
+
size 99010213
|
data/MELD/train_data_roberta.json.feature
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:05d4c1be081d894d9f2baa44e3a86882ccaa77cd0783e9c9092b39ad31d3bc81
|
| 3 |
+
size 267972749
|
data/MELD/train_data_roberta_mm.json.feature
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:40b44a9d69baccc275fd3dcbb95aaf079ac32199d32b072c652e7691a5b46bf8
|
| 3 |
+
size 364368202
|
dataloader.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataset import *
|
| 2 |
+
import pickle
|
| 3 |
+
from torch.utils.data.sampler import SubsetRandomSampler
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
import os
|
| 6 |
+
import argparse
|
| 7 |
+
import numpy as np
|
| 8 |
+
from transformers import BertTokenizer
|
| 9 |
+
|
| 10 |
+
def get_train_valid_sampler(trainset):
|
| 11 |
+
size = len(trainset)
|
| 12 |
+
idx = list(range(size))
|
| 13 |
+
return SubsetRandomSampler(idx)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_vocab(dataset_name):
|
| 17 |
+
speaker_vocab = pickle.load(open('data/%s/speaker_vocab.pkl' % (dataset_name), 'rb'))
|
| 18 |
+
label_vocab = pickle.load(open('data/%s/label_vocab.pkl' % (dataset_name), 'rb'))
|
| 19 |
+
person_vec_dir = 'data/%s/person_vect.pkl' % (dataset_name)
|
| 20 |
+
# if os.path.exists(person_vec_dir):
|
| 21 |
+
# print('Load person vec from ' + person_vec_dir)
|
| 22 |
+
# person_vec = pickle.load(open(person_vec_dir, 'rb'))
|
| 23 |
+
# else:
|
| 24 |
+
# print('Creating personality vectors')
|
| 25 |
+
# person_vec = np.random.randn(len(speaker_vocab['itos']), 100)a
|
| 26 |
+
# print('Saving personality vectors to' + person_vec_dir)
|
| 27 |
+
# with open(person_vec_dir,'wb') as f:
|
| 28 |
+
# pickle.dump(person_vec, f, -1)
|
| 29 |
+
person_vec = None
|
| 30 |
+
|
| 31 |
+
return speaker_vocab, label_vocab, person_vec
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_IEMOCAP_loaders(dataset_name = 'IEMOCAP', batch_size=32, num_workers=0, pin_memory=False, args = None):
|
| 35 |
+
print('building vocab.. ')
|
| 36 |
+
speaker_vocab, label_vocab, person_vec = load_vocab(dataset_name)
|
| 37 |
+
print('building datasets..')
|
| 38 |
+
devset = IEMOCAPDataset(dataset_name, 'dev', speaker_vocab, label_vocab, args)
|
| 39 |
+
valid_sampler = get_train_valid_sampler(devset)
|
| 40 |
+
testset = IEMOCAPDataset(dataset_name, 'test', speaker_vocab, label_vocab, args)
|
| 41 |
+
valid_loader = DataLoader(devset,
|
| 42 |
+
batch_size=batch_size,
|
| 43 |
+
sampler=valid_sampler,
|
| 44 |
+
collate_fn=devset.collate_fn,
|
| 45 |
+
num_workers=num_workers,
|
| 46 |
+
pin_memory=pin_memory)
|
| 47 |
+
|
| 48 |
+
test_loader = DataLoader(testset,
|
| 49 |
+
batch_size=batch_size,
|
| 50 |
+
collate_fn=testset.collate_fn,
|
| 51 |
+
num_workers=num_workers,
|
| 52 |
+
pin_memory=pin_memory)
|
| 53 |
+
|
| 54 |
+
return valid_loader, test_loader, speaker_vocab, label_vocab, person_vec
|
| 55 |
+
#adding babystep_index argument if using curriculum learning
|
| 56 |
+
def get_train_loader(dataset_name = 'IEMOCAP', batch_size=32, num_workers=0, pin_memory=False, args = None, babystep_index = None):
|
| 57 |
+
print('building vocab.. ')
|
| 58 |
+
speaker_vocab, label_vocab, person_vec = load_vocab(dataset_name)
|
| 59 |
+
print('building datasets..')
|
| 60 |
+
if (args.curriculum):
|
| 61 |
+
trainset = IEMOCAPDataset(dataset_name, 'train', speaker_vocab, label_vocab, args, None, babystep_index)
|
| 62 |
+
train_sampler = get_train_valid_sampler(trainset)
|
| 63 |
+
train_loader = DataLoader(trainset,
|
| 64 |
+
batch_size=batch_size,
|
| 65 |
+
sampler=train_sampler,
|
| 66 |
+
collate_fn=trainset.collate_fn,
|
| 67 |
+
num_workers=num_workers,
|
| 68 |
+
pin_memory=pin_memory)
|
| 69 |
+
# train_loaders.append(train_loader)
|
| 70 |
+
else:
|
| 71 |
+
trainset = IEMOCAPDataset(dataset_name, 'train', speaker_vocab, label_vocab, args)
|
| 72 |
+
train_sampler = get_train_valid_sampler(trainset)
|
| 73 |
+
train_loader = DataLoader(trainset,
|
| 74 |
+
batch_size=batch_size,
|
| 75 |
+
sampler=train_sampler,
|
| 76 |
+
collate_fn=trainset.collate_fn,
|
| 77 |
+
num_workers=num_workers,
|
| 78 |
+
pin_memory=pin_memory)
|
| 79 |
+
# train_loaders.append(train_loader)
|
| 80 |
+
|
| 81 |
+
return train_loader
|
dataset.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 4 |
+
import pickle, pandas as pd
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
import random
|
| 8 |
+
from pandas import DataFrame
|
| 9 |
+
|
| 10 |
+
import cl
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class IEMOCAPDataset(Dataset):
|
| 14 |
+
#babystep_index:用于curriculum learning中的索引,指示在babystep方法中使用多少个“桶”。
|
| 15 |
+
def __init__(self, dataset_name = 'IEMOCAP', split = 'train', speaker_vocab=None, label_vocab=None, args = None, tokenizer = None, babystep_index = None):
|
| 16 |
+
self.speaker_vocab = speaker_vocab #说话者的词汇表
|
| 17 |
+
self.label_vocab = label_vocab #标签的词汇表
|
| 18 |
+
self.args = args #保存其他配置参数
|
| 19 |
+
self.data = self.read(dataset_name, split, tokenizer)#调用read方法加载数据集
|
| 20 |
+
if args.curriculum and split == 'train': #这行代码处理curriculum learning 如果args.curriculum为True且当前数据是训练集(split == 'train'),则进行babystep操作。
|
| 21 |
+
self.data = self.babystep(self.getbuckets(self.data, args.bucket_number), babystep_index)
|
| 22 |
+
print(len(self.data))
|
| 23 |
+
|
| 24 |
+
self.len = len(self.data)
|
| 25 |
+
|
| 26 |
+
def read(self, dataset_name, split, tokenizer):
|
| 27 |
+
with open('data/%s/%s_data_roberta_mm.json.feature'%(dataset_name, split), encoding='utf-8') as f:
|
| 28 |
+
raw_data = json.load(f)
|
| 29 |
+
|
| 30 |
+
# process dialogue
|
| 31 |
+
dialogs = []
|
| 32 |
+
# raw_data = sorted(raw_data, key=lambda x:len(x))
|
| 33 |
+
for d in raw_data:
|
| 34 |
+
# if len(d) < 5 or len(d) > 6:
|
| 35 |
+
# continue
|
| 36 |
+
utterances = []
|
| 37 |
+
labels = []
|
| 38 |
+
speakers = []
|
| 39 |
+
features = []
|
| 40 |
+
for i,u in enumerate(d):
|
| 41 |
+
utterances.append(u['text'])
|
| 42 |
+
labels.append(self.label_vocab['stoi'][u['label']] if 'label' in u.keys() else -1)
|
| 43 |
+
speakers.append(self.speaker_vocab['stoi'][u['speaker']])
|
| 44 |
+
features.append(u['cls'][0] + u['cls'][1]+u['cls'][2])
|
| 45 |
+
#different modalities
|
| 46 |
+
#features.append(u['cls'][0])
|
| 47 |
+
#features.append(u['cls'][1])
|
| 48 |
+
#features.append(u['cls'][2])
|
| 49 |
+
#features.append(u['cls'][0] + u['cls'][1])
|
| 50 |
+
#features.append(u['cls'][0] + u['cls'][2])
|
| 51 |
+
#features.append(u['cls'][1]+u['cls'][2])
|
| 52 |
+
dialog = cl.Dialog(utterances, labels, speakers, features, self.args.dataset_name)
|
| 53 |
+
# dialogs.append({
|
| 54 |
+
# 'utterances': utterances,
|
| 55 |
+
# 'labels': labels,
|
| 56 |
+
# 'speakers':speakers,
|
| 57 |
+
# 'features': features
|
| 58 |
+
# })
|
| 59 |
+
dialogs.append(dialog)
|
| 60 |
+
if self.args.curriculum and split == 'train':
|
| 61 |
+
totalut = 0
|
| 62 |
+
totalshift = 0
|
| 63 |
+
totalspeaker = 0
|
| 64 |
+
for i in range(0, len(dialogs)):
|
| 65 |
+
totalut += dialogs[i].numberofutterances
|
| 66 |
+
totalshift += dialogs[i].numberofemotionshifts
|
| 67 |
+
totalspeaker += dialogs[i].numberofspeakers
|
| 68 |
+
# random.shuffle(dialogs)
|
| 69 |
+
dialogs.sort(key= lambda dialog: dialog.difficulty)
|
| 70 |
+
# if (split == 'train'):
|
| 71 |
+
# num_buckets = 8
|
| 72 |
+
# bucket_length = (len(dialogs) + num_buckets - 1) // num_buckets
|
| 73 |
+
# buckets = [dialogs[i:i + bucket_length] for i in range(0, len(dialogs), bucket_length)]
|
| 74 |
+
# print('')
|
| 75 |
+
else:
|
| 76 |
+
random.shuffle(dialogs)
|
| 77 |
+
return dialogs
|
| 78 |
+
|
| 79 |
+
def getbuckets(self, dialogs, num_buckets):
|
| 80 |
+
buckets = []
|
| 81 |
+
bucket_length = (len(dialogs) + num_buckets - 1) // num_buckets
|
| 82 |
+
buckets = [dialogs[i:i + bucket_length] for i in range(0, len(dialogs), bucket_length)]
|
| 83 |
+
print('bucket')
|
| 84 |
+
print(len(buckets))
|
| 85 |
+
return buckets
|
| 86 |
+
#parameter for curriculum learning
|
| 87 |
+
def babystep(self, buckets, index):
|
| 88 |
+
data = []
|
| 89 |
+
for i in range(0, index):
|
| 90 |
+
data+= buckets[i];
|
| 91 |
+
return data
|
| 92 |
+
def __getitem__(self, index):
|
| 93 |
+
'''
|
| 94 |
+
:param index:
|
| 95 |
+
:return:
|
| 96 |
+
feature,
|
| 97 |
+
label
|
| 98 |
+
speaker
|
| 99 |
+
length
|
| 100 |
+
text
|
| 101 |
+
'''
|
| 102 |
+
return torch.FloatTensor(self.data[index]['features']), \
|
| 103 |
+
torch.LongTensor(self.data[index]['labels']),\
|
| 104 |
+
self.data[index]['speakers'], \
|
| 105 |
+
len(self.data[index]['labels']), \
|
| 106 |
+
self.data[index]['utterances']
|
| 107 |
+
|
| 108 |
+
def __len__(self):
|
| 109 |
+
return self.len
|
| 110 |
+
|
| 111 |
+
def get_adj(self, speakers, max_dialog_len):
|
| 112 |
+
'''
|
| 113 |
+
get adj matrix
|
| 114 |
+
:param speakers: (B, N)
|
| 115 |
+
:param max_dialog_len:
|
| 116 |
+
:return:
|
| 117 |
+
adj: (B, N, N) adj[:,i,:] means the direct predecessors of node i
|
| 118 |
+
'''
|
| 119 |
+
adj = []
|
| 120 |
+
for speaker in speakers:
|
| 121 |
+
a = torch.zeros(max_dialog_len, max_dialog_len)
|
| 122 |
+
for i,s in enumerate(speaker):
|
| 123 |
+
get_local_pred = False
|
| 124 |
+
get_global_pred = False
|
| 125 |
+
for j in range(i - 1, -1, -1):
|
| 126 |
+
if speaker[j] == s and not get_local_pred:
|
| 127 |
+
get_local_pred = True
|
| 128 |
+
a[i,j] = 1
|
| 129 |
+
elif speaker[j] != s and not get_global_pred:
|
| 130 |
+
get_global_pred = True
|
| 131 |
+
a[i,j] = 1
|
| 132 |
+
if get_global_pred and get_local_pred:
|
| 133 |
+
break
|
| 134 |
+
adj.append(a)
|
| 135 |
+
return torch.stack(adj)
|
| 136 |
+
|
| 137 |
+
def get_adj_v1(self, speakers, max_dialog_len):
|
| 138 |
+
'''
|
| 139 |
+
get adj matrix
|
| 140 |
+
:param speakers: (B, N)
|
| 141 |
+
:param max_dialog_len:
|
| 142 |
+
:return:
|
| 143 |
+
adj: (B, N, N) adj[:,i,:] means the direct predecessors of node i
|
| 144 |
+
'''
|
| 145 |
+
adj = []
|
| 146 |
+
for speaker in speakers:
|
| 147 |
+
a = torch.zeros(max_dialog_len, max_dialog_len)
|
| 148 |
+
for i,s in enumerate(speaker):
|
| 149 |
+
cnt = 0
|
| 150 |
+
for j in range(i - 1, -1, -1):
|
| 151 |
+
a[i,j] = 1
|
| 152 |
+
if speaker[j] == s:
|
| 153 |
+
cnt += 1
|
| 154 |
+
if cnt==self.args.windowps:
|
| 155 |
+
break
|
| 156 |
+
adj.append(a)
|
| 157 |
+
return torch.stack(adj)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def get_adj_v2(self, speakers, max_dialog_len):
|
| 161 |
+
'''
|
| 162 |
+
get adj matrix
|
| 163 |
+
:param speakers: (N)
|
| 164 |
+
:param max_dialog_len:
|
| 165 |
+
:return:
|
| 166 |
+
adj: (N, N) adj[i,:] means the direct predecessors of node i
|
| 167 |
+
'''
|
| 168 |
+
adj = []
|
| 169 |
+
for speaker in speakers:
|
| 170 |
+
a = torch.zeros(max_dialog_len, max_dialog_len)
|
| 171 |
+
for i, s in enumerate(speaker):
|
| 172 |
+
cnt = 0
|
| 173 |
+
for j in range(i - 1, -1, -1):
|
| 174 |
+
a[i, j] = 1 # Assign 1 for all previous utterances
|
| 175 |
+
if speaker[j] == s: # Compare speaker strings
|
| 176 |
+
cnt += 1
|
| 177 |
+
if cnt == self.args.windowpl: # Check if window condition is met
|
| 178 |
+
break
|
| 179 |
+
adj.append(a)
|
| 180 |
+
|
| 181 |
+
return torch.stack(adj)
|
| 182 |
+
|
| 183 |
+
def get_s_mask(self, speakers, max_dialog_len):
|
| 184 |
+
'''
|
| 185 |
+
:param speakers:
|
| 186 |
+
:param max_dialog_len:
|
| 187 |
+
:return:
|
| 188 |
+
s_mask: (B, N, N) s_mask[:,i,:] means the speaker informations for predecessors of node i, where 1 denotes the same speaker, 0 denotes the different speaker
|
| 189 |
+
s_mask_onehot (B, N, N, 2) onehot emcoding of s_mask
|
| 190 |
+
'''
|
| 191 |
+
s_mask = []
|
| 192 |
+
s_mask_onehot = []
|
| 193 |
+
for speaker in speakers:
|
| 194 |
+
s = torch.zeros(max_dialog_len, max_dialog_len, dtype = torch.long)
|
| 195 |
+
s_onehot = torch.zeros(max_dialog_len, max_dialog_len, 2)
|
| 196 |
+
for i in range(len(speaker)):
|
| 197 |
+
for j in range(len(speaker)):
|
| 198 |
+
if speaker[i] == speaker[j]:
|
| 199 |
+
s[i,j] = 1
|
| 200 |
+
s_onehot[i,j,1] = 1
|
| 201 |
+
else:
|
| 202 |
+
s_onehot[i,j,0] = 1
|
| 203 |
+
|
| 204 |
+
s_mask.append(s)
|
| 205 |
+
s_mask_onehot.append(s_onehot)
|
| 206 |
+
return torch.stack(s_mask), torch.stack(s_mask_onehot)
|
| 207 |
+
|
| 208 |
+
def collate_fn(self, data):
|
| 209 |
+
'''
|
| 210 |
+
:param data:
|
| 211 |
+
features, labels, speakers, length, utterances
|
| 212 |
+
:return:
|
| 213 |
+
features: (B, N, D) padded
|
| 214 |
+
labels: (B, N) padded
|
| 215 |
+
adj: (B, N, N) adj[:,i,:] means the direct predecessors of node i
|
| 216 |
+
s_mask: (B, N, N) s_mask[:,i,:] means the speaker informations for predecessors of node i, where 1 denotes the same speaker, 0 denotes the different speaker
|
| 217 |
+
lengths: (B, )
|
| 218 |
+
utterances: not a tensor
|
| 219 |
+
'''
|
| 220 |
+
max_dialog_len = max([d[3] for d in data])
|
| 221 |
+
feaures = pad_sequence([d[0] for d in data], batch_first = True) # (B, N, D)
|
| 222 |
+
labels = pad_sequence([d[1] for d in data], batch_first = True, padding_value = -1) # (B, N )
|
| 223 |
+
adj_1 = self.get_adj_v1([d[2] for d in data], max_dialog_len)
|
| 224 |
+
adj_2 = self.get_adj_v2([d[2] for d in data], max_dialog_len)
|
| 225 |
+
s_mask, s_mask_onehot = self.get_s_mask([d[2] for d in data], max_dialog_len)
|
| 226 |
+
lengths = torch.LongTensor([d[3] for d in data])
|
| 227 |
+
speakers = pad_sequence([torch.LongTensor(d[2]) for d in data], batch_first = True, padding_value = -1)
|
| 228 |
+
utterances = [d[4] for d in data]
|
| 229 |
+
|
| 230 |
+
return feaures, labels, adj_1, adj_2, s_mask, s_mask_onehot,lengths, speakers, utterances
|
evaluate.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
|
| 3 |
+
import numpy as np, argparse, time, pickle, random
|
| 4 |
+
import torch
|
| 5 |
+
import matplotlib
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from dataloader import IEMOCAPDataset
|
| 9 |
+
from model import *
|
| 10 |
+
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score, classification_report, \
|
| 11 |
+
precision_recall_fscore_support, ConfusionMatrixDisplay
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
from trainer import train_or_eval_model, save_badcase
|
| 14 |
+
from dataset import IEMOCAPDataset
|
| 15 |
+
from dataloader import get_IEMOCAP_loaders
|
| 16 |
+
from transformers import AdamW
|
| 17 |
+
import copy
|
| 18 |
+
|
| 19 |
+
# We use seed = 100 for reproduction of the results reported in the paper.
|
| 20 |
+
seed = 100
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def seed_everything(seed=seed):
|
| 24 |
+
random.seed(seed)
|
| 25 |
+
np.random.seed(seed)
|
| 26 |
+
torch.manual_seed(seed)
|
| 27 |
+
torch.cuda.manual_seed(seed)
|
| 28 |
+
torch.cuda.manual_seed_all(seed)
|
| 29 |
+
torch.backends.cudnn.benchmark = False
|
| 30 |
+
torch.backends.cudnn.deterministic = True
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def evaluate(model, dataloader, cuda, args, speaker_vocab, label_vocab):
|
| 34 |
+
preds, labels = [], []
|
| 35 |
+
scores, vids = [], []
|
| 36 |
+
dialogs = []
|
| 37 |
+
speakers = []
|
| 38 |
+
|
| 39 |
+
model.eval()
|
| 40 |
+
|
| 41 |
+
for data in dataloader:
|
| 42 |
+
|
| 43 |
+
features, label, adj,s_mask, s_mask_onehot,lengths, speaker, utterances = data
|
| 44 |
+
if cuda:
|
| 45 |
+
features = features.cuda()
|
| 46 |
+
label = label.cuda()
|
| 47 |
+
adj = adj.cuda()
|
| 48 |
+
s_mask_onehot = s_mask_onehot.cuda()
|
| 49 |
+
s_mask = s_mask.cuda()
|
| 50 |
+
lengths = lengths.cuda()
|
| 51 |
+
|
| 52 |
+
log_prob = model(features, adj,s_mask, s_mask_onehot, lengths) # (B, N, C)
|
| 53 |
+
|
| 54 |
+
label = label.cpu().numpy().tolist() # (B, N)
|
| 55 |
+
pred = torch.argmax(log_prob, dim = 2).cpu().numpy().tolist() # (B, N)
|
| 56 |
+
preds += pred
|
| 57 |
+
labels += label
|
| 58 |
+
dialogs += utterances
|
| 59 |
+
speakers += speaker
|
| 60 |
+
|
| 61 |
+
if preds != []:
|
| 62 |
+
new_preds = []
|
| 63 |
+
new_labels = []
|
| 64 |
+
for i,label in enumerate(labels):
|
| 65 |
+
for j,l in enumerate(label):
|
| 66 |
+
if l != -1:
|
| 67 |
+
new_labels.append(l)
|
| 68 |
+
new_preds.append(preds[i][j])
|
| 69 |
+
else:
|
| 70 |
+
return
|
| 71 |
+
|
| 72 |
+
avg_accuracy = round(accuracy_score(new_labels, new_preds) * 100, 2)
|
| 73 |
+
if args.dataset_name in ['IEMOCAP', 'MELD', 'EmoryNLP']:
|
| 74 |
+
avg_fscore = round(f1_score(new_labels, new_preds, average='weighted') * 100, 2)
|
| 75 |
+
# get f1 score for each class to generate confusion matrix
|
| 76 |
+
# fscore_perclass = f1_score(new_labels, new_preds, average=None)
|
| 77 |
+
# print('fscore_perclass', fscore_perclass)
|
| 78 |
+
print('test_accuracy', avg_accuracy)
|
| 79 |
+
print('test_f1', avg_fscore)
|
| 80 |
+
# confusion matrix test, not working on colab
|
| 81 |
+
# print(new_labels)
|
| 82 |
+
# cm = confusion_matrix(new_labels, new_preds, labels=[0, 1, 2, 3, 4, 5, 6])
|
| 83 |
+
# print(cm)
|
| 84 |
+
# per_class_accuracies = {}
|
| 85 |
+
#
|
| 86 |
+
# # Calculate the accuracy for each one of our classes
|
| 87 |
+
# for idx, cls in enumerate(label_vocab['itos']):
|
| 88 |
+
# # True negatives are all the samples that are not our current GT class (not the current row)
|
| 89 |
+
# # and were not predicted as the current class (not the current column)
|
| 90 |
+
# true_negatives = np.sum(np.delete(np.delete(cm, idx, axis=0), idx, axis=1))
|
| 91 |
+
#
|
| 92 |
+
# # True positives are all the samples of our current GT class that were predicted as such
|
| 93 |
+
# true_positives = cm[idx, idx]
|
| 94 |
+
#
|
| 95 |
+
# # The accuracy for the current class is the ratio between correct predictions to all predictions
|
| 96 |
+
# per_class_accuracies[cls] = (true_positives + true_negatives) / np.sum(cm)
|
| 97 |
+
# print('acc:', per_class_accuracies)
|
| 98 |
+
# disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_vocab['itos'])
|
| 99 |
+
# disp.plot()
|
| 100 |
+
# plt.show()
|
| 101 |
+
return
|
| 102 |
+
else:
|
| 103 |
+
avg_micro_fscore = round(f1_score(new_labels, new_preds, average='micro', labels=list(range(1, 7))) * 100, 2)
|
| 104 |
+
avg_macro_fscore = round(f1_score(new_labels, new_preds, average='macro') * 100, 2)
|
| 105 |
+
print('test_accuracy', avg_accuracy)
|
| 106 |
+
print('test_micro_f1', avg_micro_fscore)
|
| 107 |
+
print('test_macro_f1', avg_macro_fscore)
|
| 108 |
+
return
|
| 109 |
+
|
| 110 |
+
if __name__ == '__main__':
|
| 111 |
+
|
| 112 |
+
#path = './saved_models/'
|
| 113 |
+
|
| 114 |
+
parser = argparse.ArgumentParser()
|
| 115 |
+
parser.add_argument('--bert_model_dir', type=str, default='')
|
| 116 |
+
parser.add_argument('--bert_tokenizer_dir', type=str, default='')
|
| 117 |
+
|
| 118 |
+
parser.add_argument('--state_dict_file', type=str, default='')
|
| 119 |
+
|
| 120 |
+
parser.add_argument('--bert_dim', type = int, default=1024)
|
| 121 |
+
parser.add_argument('--hidden_dim', type = int, default=300)
|
| 122 |
+
parser.add_argument('--mlp_layers', type=int, default=2, help='Number of output mlp layers.')
|
| 123 |
+
parser.add_argument('--gnn_layers', type=int, default=2, help='Number of gnn layers.')
|
| 124 |
+
parser.add_argument('--emb_dim', type=int, default=1024, help='Feature size.')
|
| 125 |
+
|
| 126 |
+
parser.add_argument('--attn_type', type=str, default='rgcn', choices=['dotprod','linear','bilinear', 'rgcn'], help='Feature size.')
|
| 127 |
+
parser.add_argument('--no_rel_attn', action='store_true', default=False, help='no relation for edges' )
|
| 128 |
+
|
| 129 |
+
parser.add_argument('--max_sent_len', type=int, default=200,
|
| 130 |
+
help='max content length for each text, if set to 0, then the max length has no constrain')
|
| 131 |
+
|
| 132 |
+
parser.add_argument('--no_cuda', action='store_true', default=False, help='does not use GPU')
|
| 133 |
+
|
| 134 |
+
parser.add_argument('--dataset_name', default='IEMOCAP', type= str, help='dataset name, IEMOCAP or MELD or DailyDialog')
|
| 135 |
+
|
| 136 |
+
parser.add_argument('--windowp', type=int, default=1,
|
| 137 |
+
help='context window size for constructing edges in graph model for past utterances')
|
| 138 |
+
|
| 139 |
+
parser.add_argument('--windowf', type=int, default=0,
|
| 140 |
+
help='context window size for constructing edges in graph model for future utterances')
|
| 141 |
+
|
| 142 |
+
parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
|
| 143 |
+
|
| 144 |
+
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', help='learning rate')
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
parser.add_argument('--dropout', type=float, default=0, metavar='dropout', help='dropout rate')
|
| 148 |
+
|
| 149 |
+
parser.add_argument('--batch_size', type=int, default=8, metavar='BS', help='batch size')
|
| 150 |
+
|
| 151 |
+
parser.add_argument('--epochs', type=int, default=20, metavar='E', help='number of epochs')
|
| 152 |
+
|
| 153 |
+
parser.add_argument('--tensorboard', action='store_true', default=False, help='Enables tensorboard log')
|
| 154 |
+
|
| 155 |
+
parser.add_argument('--nodal_att_type', type=str, default=None, choices=['global', 'past'],
|
| 156 |
+
help='type of nodal attention')
|
| 157 |
+
|
| 158 |
+
parser.add_argument('--curriculum', action='store_true', default=False, help='Enables curriculum learning')
|
| 159 |
+
|
| 160 |
+
parser.add_argument('--bucket_number', type=int, default=0, help='Number of buckets using')
|
| 161 |
+
|
| 162 |
+
args = parser.parse_args()
|
| 163 |
+
print(args)
|
| 164 |
+
|
| 165 |
+
seed_everything()
|
| 166 |
+
|
| 167 |
+
args.cuda = torch.cuda.is_available() and not args.no_cuda
|
| 168 |
+
|
| 169 |
+
if args.cuda:
|
| 170 |
+
print('Running on GPU')
|
| 171 |
+
else:
|
| 172 |
+
print('Running on CPU')
|
| 173 |
+
|
| 174 |
+
if args.tensorboard:
|
| 175 |
+
from tensorboardX import SummaryWriter
|
| 176 |
+
|
| 177 |
+
writer = SummaryWriter()
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
cuda = args.cuda
|
| 181 |
+
n_epochs = args.epochs
|
| 182 |
+
batch_size = args.batch_size
|
| 183 |
+
valid_loader, test_loader, speaker_vocab, label_vocab, person_vec = get_IEMOCAP_loaders(
|
| 184 |
+
dataset_name=args.dataset_name, batch_size=batch_size, num_workers=0, args=args)
|
| 185 |
+
n_classes = len(label_vocab['itos'])
|
| 186 |
+
|
| 187 |
+
print('building model..')
|
| 188 |
+
model = DAGERC_fushion(args, n_classes)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
if torch.cuda.device_count() > 1:
|
| 192 |
+
print('Multi-GPU...........')
|
| 193 |
+
model = nn.DataParallel(model,device_ids = range(torch.cuda.device_count()))
|
| 194 |
+
if cuda:
|
| 195 |
+
model.cuda()
|
| 196 |
+
|
| 197 |
+
state_dict = torch.load(args.state_dict_file)
|
| 198 |
+
model.load_state_dict(state_dict)
|
| 199 |
+
evaluate(model, test_loader, cuda, args, speaker_vocab, label_vocab)
|
model.py
ADDED
|
@@ -0,0 +1,1199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np, itertools, random, copy, math
|
| 5 |
+
from transformers import BertModel, BertConfig
|
| 6 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead
|
| 7 |
+
from model_utils import *
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BertERC(nn.Module):
|
| 11 |
+
|
| 12 |
+
def __init__(self, args, num_class):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.args = args
|
| 15 |
+
# gcn layer
|
| 16 |
+
|
| 17 |
+
self.dropout = nn.Dropout(args.dropout)
|
| 18 |
+
# bert_encoder
|
| 19 |
+
self.bert_config = BertConfig.from_json_file(args.bert_model_dir + 'config.json')
|
| 20 |
+
|
| 21 |
+
self.bert = BertModel.from_pretrained(args.home_dir + args.bert_model_dir, config = self.bert_config)
|
| 22 |
+
in_dim = args.bert_dim
|
| 23 |
+
|
| 24 |
+
# output mlp layers
|
| 25 |
+
layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
|
| 26 |
+
for _ in range(args.mlp_layers- 1):
|
| 27 |
+
layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
|
| 28 |
+
layers += [nn.Linear(args.hidden_dim, num_class)]
|
| 29 |
+
|
| 30 |
+
self.out_mlp = nn.Sequential(*layers)
|
| 31 |
+
|
| 32 |
+
def forward(self, content_ids, token_types,utterance_len,seq_len):
|
| 33 |
+
|
| 34 |
+
# the embeddings for bert
|
| 35 |
+
# if len(content_ids)>512:
|
| 36 |
+
# print('ll')
|
| 37 |
+
|
| 38 |
+
#
|
| 39 |
+
## w token_type_ids
|
| 40 |
+
# lastHidden = self.bert(content_ids, token_type_ids = token_types)[1] #(N , D)
|
| 41 |
+
## w/t token_type_ids
|
| 42 |
+
lastHidden = self.bert(content_ids)[1] #(N , D)
|
| 43 |
+
|
| 44 |
+
final_feature = self.dropout(lastHidden)
|
| 45 |
+
|
| 46 |
+
# pooling
|
| 47 |
+
|
| 48 |
+
outputs = self.out_mlp(final_feature) #(N, D)
|
| 49 |
+
|
| 50 |
+
return outputs
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class DAGERC(nn.Module):
|
| 54 |
+
|
| 55 |
+
def __init__(self, args, num_class):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.args = args
|
| 58 |
+
# gcn layer
|
| 59 |
+
|
| 60 |
+
self.dropout = nn.Dropout(args.dropout)
|
| 61 |
+
|
| 62 |
+
self.gnn_layers = args.gnn_layers
|
| 63 |
+
|
| 64 |
+
if not args.no_rel_attn:
|
| 65 |
+
self.rel_emb = nn.Embedding(2,args.hidden_dim)
|
| 66 |
+
self.rel_attn = True
|
| 67 |
+
else:
|
| 68 |
+
self.rel_attn = False
|
| 69 |
+
|
| 70 |
+
if self.args.attn_type == 'linear':
|
| 71 |
+
gats = []
|
| 72 |
+
for _ in range(args.gnn_layers):
|
| 73 |
+
gats += [GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)]
|
| 74 |
+
self.gather = nn.ModuleList(gats)
|
| 75 |
+
else:
|
| 76 |
+
gats = []
|
| 77 |
+
for _ in range(args.gnn_layers):
|
| 78 |
+
gats += [Gatdot(args.hidden_dim) if args.no_rel_attn else Gatdot_rel(args.hidden_dim)]
|
| 79 |
+
self.gather = nn.ModuleList(gats)
|
| 80 |
+
|
| 81 |
+
grus = []
|
| 82 |
+
for _ in range(args.gnn_layers):
|
| 83 |
+
grus += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 84 |
+
self.grus = nn.ModuleList(grus)
|
| 85 |
+
|
| 86 |
+
self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)
|
| 87 |
+
|
| 88 |
+
in_dim = args.hidden_dim * (args.gnn_layers + 1) + args.emb_dim
|
| 89 |
+
# output mlp layers
|
| 90 |
+
layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
|
| 91 |
+
for _ in range(args.mlp_layers - 1):
|
| 92 |
+
layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
|
| 93 |
+
layers += [nn.Linear(args.hidden_dim, num_class)]
|
| 94 |
+
|
| 95 |
+
self.out_mlp = nn.Sequential(*layers)
|
| 96 |
+
|
| 97 |
+
def forward(self, features, adj,s_mask):
|
| 98 |
+
'''
|
| 99 |
+
:param features: (B, N, D)
|
| 100 |
+
:param adj: (B, N, N)
|
| 101 |
+
:param s_mask: (B, N, N)
|
| 102 |
+
:return:
|
| 103 |
+
'''
|
| 104 |
+
num_utter = features.size()[1]
|
| 105 |
+
if self.rel_attn:
|
| 106 |
+
rel_ft = self.rel_emb(s_mask) # (B, N, N, D)
|
| 107 |
+
|
| 108 |
+
H0 = F.relu(self.fc1(features)) # (B, N, D)
|
| 109 |
+
H = [H0]
|
| 110 |
+
for l in range(self.args.gnn_layers):
|
| 111 |
+
H1 = self.grus[l](H[l][:,0,:]).unsqueeze(1) # (B, 1, D)
|
| 112 |
+
for i in range(1, num_utter):
|
| 113 |
+
if not self.rel_attn:
|
| 114 |
+
_, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i])
|
| 115 |
+
else:
|
| 116 |
+
_, M = self.gather[l](H[l][:, i, :], H1, H1, adj[:, i, :i], rel_ft[:, i, :i, :])
|
| 117 |
+
H1 = torch.cat((H1 , self.grus[l](H[l][:,i,:], M).unsqueeze(1)), dim = 1)
|
| 118 |
+
# print('H1', H1.size())
|
| 119 |
+
# print('----------------------------------------------------')
|
| 120 |
+
H.append(H1)
|
| 121 |
+
H0 = H1
|
| 122 |
+
H.append(features)
|
| 123 |
+
H = torch.cat(H, dim = 2) #(B, N, l*D)
|
| 124 |
+
logits = self.out_mlp(H)
|
| 125 |
+
return logits
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class DAGERC_fushion(nn.Module):
|
| 130 |
+
|
| 131 |
+
def __init__(self, args, num_class):
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.args = args
|
| 134 |
+
# gcn layer
|
| 135 |
+
|
| 136 |
+
self.dropout = nn.Dropout(args.dropout)
|
| 137 |
+
|
| 138 |
+
self.gnn_layers = args.gnn_layers
|
| 139 |
+
|
| 140 |
+
if not args.no_rel_attn:
|
| 141 |
+
self.rel_attn = True
|
| 142 |
+
else:
|
| 143 |
+
self.rel_attn = False
|
| 144 |
+
|
| 145 |
+
if self.args.attn_type == 'linear':
|
| 146 |
+
gats = []
|
| 147 |
+
for _ in range(args.gnn_layers):
|
| 148 |
+
gats += [GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)]
|
| 149 |
+
self.gather = nn.ModuleList(gats)
|
| 150 |
+
elif self.args.attn_type == 'dotprod':
|
| 151 |
+
gats = []
|
| 152 |
+
for _ in range(args.gnn_layers):
|
| 153 |
+
gats += [GatDot(args.hidden_dim) if args.no_rel_attn else GatDot_rel(args.hidden_dim)]
|
| 154 |
+
self.gather = nn.ModuleList(gats)
|
| 155 |
+
elif self.args.attn_type == 'rgcn':
|
| 156 |
+
gats = []
|
| 157 |
+
for _ in range(args.gnn_layers):
|
| 158 |
+
# gats += [GAT_dialoggcn(args.hidden_dim)]
|
| 159 |
+
gats += [GAT_dialoggcn_v1(args.hidden_dim)]
|
| 160 |
+
self.gather = nn.ModuleList(gats)
|
| 161 |
+
|
| 162 |
+
grus_c = []
|
| 163 |
+
for _ in range(args.gnn_layers):
|
| 164 |
+
grus_c += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 165 |
+
self.grus_c = nn.ModuleList(grus_c)
|
| 166 |
+
|
| 167 |
+
grus_p = []
|
| 168 |
+
for _ in range(args.gnn_layers):
|
| 169 |
+
grus_p += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 170 |
+
self.grus_p = nn.ModuleList(grus_p)
|
| 171 |
+
|
| 172 |
+
fcs = []
|
| 173 |
+
for _ in range(args.gnn_layers):
|
| 174 |
+
fcs += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
|
| 175 |
+
self.fcs = nn.ModuleList(fcs)
|
| 176 |
+
|
| 177 |
+
self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)
|
| 178 |
+
|
| 179 |
+
self.nodal_att_type = args.nodal_att_type
|
| 180 |
+
|
| 181 |
+
in_dim = args.hidden_dim * (args.gnn_layers + 1) + args.emb_dim
|
| 182 |
+
|
| 183 |
+
# output mlp layers
|
| 184 |
+
layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
|
| 185 |
+
for _ in range(args.mlp_layers - 1):
|
| 186 |
+
layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
|
| 187 |
+
layers += [self.dropout]
|
| 188 |
+
layers += [nn.Linear(args.hidden_dim, num_class)]
|
| 189 |
+
|
| 190 |
+
self.out_mlp = nn.Sequential(*layers)
|
| 191 |
+
|
| 192 |
+
self.attentive_node_features = attentive_node_features(in_dim)
|
| 193 |
+
|
| 194 |
+
def forward(self, features, adj,s_mask,s_mask_onehot, lengths):
|
| 195 |
+
'''
|
| 196 |
+
:param features: (B, N, D)
|
| 197 |
+
:param adj: (B, N, N)
|
| 198 |
+
:param s_mask: (B, N, N)
|
| 199 |
+
:param s_mask_onehot: (B, N, N, 2)
|
| 200 |
+
:return:
|
| 201 |
+
'''
|
| 202 |
+
num_utter = features.size()[1]
|
| 203 |
+
|
| 204 |
+
H0 = F.relu(self.fc1(features))
|
| 205 |
+
# H0 = self.dropout(H0)
|
| 206 |
+
H = [H0]
|
| 207 |
+
for l in range(self.args.gnn_layers):
|
| 208 |
+
C = self.grus_c[l](H[l][:,0,:]).unsqueeze(1)
|
| 209 |
+
M = torch.zeros_like(C).squeeze(1)
|
| 210 |
+
# P = M.unsqueeze(1)
|
| 211 |
+
P = self.grus_p[l](M, H[l][:,0,:]).unsqueeze(1)
|
| 212 |
+
#H1 = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
|
| 213 |
+
#H1 = F.relu(C+P)
|
| 214 |
+
H1 = C+P
|
| 215 |
+
for i in range(1, num_utter):
|
| 216 |
+
# print(i,num_utter)
|
| 217 |
+
if self.args.attn_type == 'rgcn':
|
| 218 |
+
_, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i], s_mask[:,i,:i])
|
| 219 |
+
# _, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i], s_mask_onehot[:,i,:i,:])
|
| 220 |
+
else:
|
| 221 |
+
if not self.rel_attn:
|
| 222 |
+
_, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i])
|
| 223 |
+
else:
|
| 224 |
+
_, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i], s_mask[:, i, :i])
|
| 225 |
+
|
| 226 |
+
C = self.grus_c[l](H[l][:,i,:], M).unsqueeze(1)
|
| 227 |
+
P = self.grus_p[l](M, H[l][:,i,:]).unsqueeze(1)
|
| 228 |
+
# P = M.unsqueeze(1)
|
| 229 |
+
#H_temp = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
|
| 230 |
+
#H_temp = F.relu(C+P)
|
| 231 |
+
H_temp = C+P
|
| 232 |
+
H1 = torch.cat((H1 , H_temp), dim = 1)
|
| 233 |
+
# print('H1', H1.size())
|
| 234 |
+
# print('----------------------------------------------------')
|
| 235 |
+
H.append(H1)
|
| 236 |
+
H.append(features)
|
| 237 |
+
|
| 238 |
+
H = torch.cat(H, dim = 2)
|
| 239 |
+
|
| 240 |
+
H = self.attentive_node_features(H,lengths,self.nodal_att_type)
|
| 241 |
+
|
| 242 |
+
logits = self.out_mlp(H)
|
| 243 |
+
|
| 244 |
+
return logits
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
#仅仅使用最后一层的short和long,concat;只用过去特征
|
| 248 |
+
#Only use the final layer's short and long features, concatenated; use only past features.
|
| 249 |
+
class DAGERC_new_1(nn.Module):
|
| 250 |
+
|
| 251 |
+
def __init__(self, args, num_class):
|
| 252 |
+
super().__init__()
|
| 253 |
+
self.args = args
|
| 254 |
+
# gcn layer
|
| 255 |
+
|
| 256 |
+
self.dropout = nn.Dropout(args.dropout)
|
| 257 |
+
|
| 258 |
+
self.gnn_layers = args.gnn_layers
|
| 259 |
+
|
| 260 |
+
if not args.no_rel_attn:
|
| 261 |
+
self.rel_attn = True
|
| 262 |
+
else:
|
| 263 |
+
self.rel_attn = False
|
| 264 |
+
|
| 265 |
+
if self.args.attn_type == 'linear':
|
| 266 |
+
gats = []
|
| 267 |
+
for _ in range(args.gnn_layers):
|
| 268 |
+
gats += [GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)]
|
| 269 |
+
self.gather = nn.ModuleList(gats)
|
| 270 |
+
elif self.args.attn_type == 'dotprod':
|
| 271 |
+
gats = []
|
| 272 |
+
for _ in range(args.gnn_layers):
|
| 273 |
+
gats += [GatDot(args.hidden_dim) if args.no_rel_attn else GatDot_rel(args.hidden_dim)]
|
| 274 |
+
self.gather = nn.ModuleList(gats)
|
| 275 |
+
elif self.args.attn_type == 'rgcn':
|
| 276 |
+
#短距离
|
| 277 |
+
gats_short = []
|
| 278 |
+
gats_long = []
|
| 279 |
+
for _ in range(args.gnn_layers):
|
| 280 |
+
gats_short += [GAT_dialoggcn_v1(args.hidden_dim)]
|
| 281 |
+
for _ in range(args.gnn_layers):
|
| 282 |
+
gats_long += [GAT_dialoggcn_v1(args.hidden_dim)]
|
| 283 |
+
self.gather_short = nn.ModuleList(gats_short)
|
| 284 |
+
self.gather_long = nn.ModuleList(gats_long)
|
| 285 |
+
|
| 286 |
+
# 近距离 GRU
|
| 287 |
+
grus_c_short = []
|
| 288 |
+
for _ in range(args.gnn_layers):
|
| 289 |
+
grus_c_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 290 |
+
self.grus_c_short = nn.ModuleList(grus_c_short)
|
| 291 |
+
|
| 292 |
+
# 远距离 GRU
|
| 293 |
+
grus_c_long = []
|
| 294 |
+
for _ in range(args.gnn_layers):
|
| 295 |
+
grus_c_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 296 |
+
self.grus_c_long = nn.ModuleList(grus_c_long)
|
| 297 |
+
|
| 298 |
+
grus_p_short = []
|
| 299 |
+
for _ in range(args.gnn_layers):
|
| 300 |
+
grus_p_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 301 |
+
self.grus_p_short = nn.ModuleList(grus_p_short)
|
| 302 |
+
|
| 303 |
+
grus_p_long = []
|
| 304 |
+
for _ in range(args.gnn_layers):
|
| 305 |
+
grus_p_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 306 |
+
self.grus_p_long = nn.ModuleList(grus_p_long)
|
| 307 |
+
|
| 308 |
+
#近距离全链接层
|
| 309 |
+
fcs_short = []
|
| 310 |
+
for _ in range(args.gnn_layers):
|
| 311 |
+
fcs_short += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
|
| 312 |
+
self.fcs_short = nn.ModuleList(fcs_short)
|
| 313 |
+
|
| 314 |
+
# 远距离全连接层
|
| 315 |
+
fcs_long = []
|
| 316 |
+
for _ in range(args.gnn_layers):
|
| 317 |
+
fcs_long += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
|
| 318 |
+
self.fcs_long = nn.ModuleList(fcs_long)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)
|
| 322 |
+
|
| 323 |
+
self.nodal_att_type = args.nodal_att_type
|
| 324 |
+
|
| 325 |
+
in_dim = ((args.hidden_dim*2)+ args.emb_dim)
|
| 326 |
+
# print(in_dim)
|
| 327 |
+
# output mlp layers
|
| 328 |
+
layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
|
| 329 |
+
for _ in range(args.mlp_layers - 1):
|
| 330 |
+
layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
|
| 331 |
+
layers += [self.dropout]
|
| 332 |
+
layers += [nn.Linear(args.hidden_dim, num_class)]
|
| 333 |
+
|
| 334 |
+
self.out_mlp = nn.Sequential(*layers)
|
| 335 |
+
|
| 336 |
+
self.attentive_node_features = attentive_node_features(in_dim)
|
| 337 |
+
|
| 338 |
+
self.affine1 = nn.Parameter(torch.empty(size=((args.hidden_dim) , (args.hidden_dim) )))
|
| 339 |
+
nn.init.xavier_uniform_(self.affine1.data, gain=1.414)
|
| 340 |
+
self.affine2 = nn.Parameter(torch.empty(size=((args.hidden_dim) , (args.hidden_dim) )))
|
| 341 |
+
nn.init.xavier_uniform_(self.affine2.data, gain=1.414)
|
| 342 |
+
|
| 343 |
+
self.diff_loss = DiffLoss(args)
|
| 344 |
+
self.beta = args.diffloss
|
| 345 |
+
|
| 346 |
+
def forward(self, features, adj_1, adj_2 ,s_mask, s_mask_onehot, lengths):
|
| 347 |
+
# 检查 H1 和 H2 是否完全相等
|
| 348 |
+
are_equal = all(torch.equal(h1, h2) for h1, h2 in zip(adj_1, adj_2))
|
| 349 |
+
# print("adj1 和 adj2 是否完全相等:", are_equal)
|
| 350 |
+
# print('adj1',adj_1)
|
| 351 |
+
# print('----------------------------------------------------')
|
| 352 |
+
|
| 353 |
+
# print('adj2',adj_2)
|
| 354 |
+
# print('----------------------------------------------------')
|
| 355 |
+
|
| 356 |
+
num_utter = features.size()[1]
|
| 357 |
+
|
| 358 |
+
H0 = F.relu(self.fc1(features))
|
| 359 |
+
#print('H0', H0.size())
|
| 360 |
+
# H0 = self.dropout(H0)
|
| 361 |
+
H = [H0]
|
| 362 |
+
H_combined_short_list = []
|
| 363 |
+
#对短距离特征进行处理
|
| 364 |
+
for l in range(self.args.gnn_layers):
|
| 365 |
+
C = self.grus_c_short[l](H[l][:,0,:]).unsqueeze(1) #针对每一层的第一个节点,使用 GRU 单元更新节点特征并聚合信息。
|
| 366 |
+
M = torch.zeros_like(C).squeeze(1) #初始化一个聚合信息张量 M(全零张量),并使用它与节点特征结合生成额外的特征 P。
|
| 367 |
+
# P = M.unsqueeze(1)
|
| 368 |
+
P = self.grus_p_short[l](M, H[l][:,0,:]).unsqueeze(1) #使用 M(全零张量)和第一个节点的特征 H[l][:, 0, :] 作为输入,得到额外特征 P,形状为 (B, D)
|
| 369 |
+
#H1 = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
|
| 370 |
+
#H1 = F.relu(C+P)
|
| 371 |
+
H1 = C+P#将更新后的特征 C 与额外特征 P 相加,生成新的节点特征 H1,为后续层的计算做准备。
|
| 372 |
+
for i in range(1, num_utter):
|
| 373 |
+
# print(i,num_utter)
|
| 374 |
+
if self.args.attn_type == 'rgcn':
|
| 375 |
+
#将 H[l][:, i, :](当前节点特征),H1(之前节点的特征聚合结果),adj[:, i, :i](当前节点与之前节点的邻接矩阵)
|
| 376 |
+
#s_mask[:, i, :i](当前节点的掩码),得到聚合结果 M
|
| 377 |
+
_, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i], s_mask[:,i,:i])
|
| 378 |
+
# _, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i], s_mask_onehot[:,i,:i,:])
|
| 379 |
+
else:
|
| 380 |
+
if not self.rel_attn:
|
| 381 |
+
_, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i])
|
| 382 |
+
else:
|
| 383 |
+
_, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i], s_mask[:, i, :i])
|
| 384 |
+
|
| 385 |
+
#使用 GRU 单元 self.grus_c[l] 来处理当前节点的特征 H[l][:, i, :] 和聚合后的特征 M,得到新的特��� C。
|
| 386 |
+
# 这表明当前节点的特征更新与其邻居的聚合信息有关。
|
| 387 |
+
C = self.grus_c_short[l](H[l][:,i,:], M).unsqueeze(1)
|
| 388 |
+
#使用另一个 GRU 单元 self.grus_p[l] 来处理聚合特征 M 和当前节点的特征 H[l][:, i, :],得到额外的特征 P。
|
| 389 |
+
P = self.grus_p_short[l](M, H[l][:,i,:]).unsqueeze(1)
|
| 390 |
+
# P = M.unsqueeze(1)
|
| 391 |
+
#H_temp = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
|
| 392 |
+
#H_temp = F.relu(C+P)
|
| 393 |
+
H_temp = C+P#将更新后的特征 C 和额外特征 P 进行相加,生成新的节点特征 H_temp
|
| 394 |
+
H1 = torch.cat((H1 , H_temp), dim = 1) #将当前节点的特征 H_temp 拼接到 H1 中。
|
| 395 |
+
# print('H1', H1.size())
|
| 396 |
+
#print('----------------------------------------------------')
|
| 397 |
+
H.append(H1)
|
| 398 |
+
H_combined_short_list.append(H[l+1])
|
| 399 |
+
'''
|
| 400 |
+
下面对长距离特征进行处理 The following processes the long-distance features.
|
| 401 |
+
'''
|
| 402 |
+
H_long = [H0] # 初始化 H_long
|
| 403 |
+
H_combined_long_list = [] # 存储长距离处理的结果
|
| 404 |
+
|
| 405 |
+
# 对长距离特征进行处理
|
| 406 |
+
for l in range(self.args.gnn_layers):
|
| 407 |
+
C_long = self.grus_c_long[l](H_long[l][:,0,:]).unsqueeze(1) # 使用 GRU 更新长距离的第一个节点
|
| 408 |
+
M_long = torch.zeros_like(C_long).squeeze(1) # 初始化长距离的聚合信息张量 M_long
|
| 409 |
+
P_long = self.grus_p_long[l](M_long, H_long[l][:,0,:]).unsqueeze(1) # 生成额外的特征 P_long
|
| 410 |
+
|
| 411 |
+
H1_long = C_long + P_long # 生成新的长距离节点特征 H1_long
|
| 412 |
+
for i in range(1, num_utter):
|
| 413 |
+
# 依据不同的 attention 类型,进行特征聚合
|
| 414 |
+
if self.args.attn_type == 'rgcn':
|
| 415 |
+
_, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i], s_mask[:,i,:i])
|
| 416 |
+
else:
|
| 417 |
+
if not self.rel_attn:
|
| 418 |
+
_, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i])
|
| 419 |
+
else:
|
| 420 |
+
_, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i], s_mask[:,i,:i])
|
| 421 |
+
|
| 422 |
+
# 使用 GRU 更新当前节点的特征 C_long 和 M_long
|
| 423 |
+
C_long = self.grus_c_long[l](H_long[l][:,i,:], M_long).unsqueeze(1)
|
| 424 |
+
P_long = self.grus_p_long[l](M_long, H_long[l][:,i,:]).unsqueeze(1)
|
| 425 |
+
|
| 426 |
+
H_temp_long = C_long + P_long # 将更新后的特征 C_long 和 P_long 相加生成新特征
|
| 427 |
+
H1_long = torch.cat((H1_long, H_temp_long), dim=1) # 将特征拼接到 H1_long 中
|
| 428 |
+
H_long.append(H1_long) # 更新 H_long 列表
|
| 429 |
+
H_combined_long_list.append(H_long[l+1])
|
| 430 |
+
|
| 431 |
+
'''
|
| 432 |
+
两个通道特征都提取完毕! Both short- and long-distance channel features have been extracted!
|
| 433 |
+
'''
|
| 434 |
+
# print('H_combined_short_list',H_combined_short_list)
|
| 435 |
+
# print('H_combined_long_list',H_combined_long_list)
|
| 436 |
+
# are_equal = all(torch.equal(h1, h2) for h1, h2 in zip(H_combined_short_list, H_combined_long_list))
|
| 437 |
+
# print("H_combined_short_list 和 H_combined_long_list 是否完全相等:", are_equal)
|
| 438 |
+
# for idx, tensor in enumerate(H_combined_short_list):
|
| 439 |
+
# print(f"H_combined_short_list[{idx}] shape: {tensor.shape}")
|
| 440 |
+
H_final = []
|
| 441 |
+
# print("H2 shape:", H2.shape)
|
| 442 |
+
# 计算差异正则化损失
|
| 443 |
+
diff_loss = 0
|
| 444 |
+
for l in range(self.args.gnn_layers):
|
| 445 |
+
# print('周期:', l)
|
| 446 |
+
HShort_prime = H_combined_short_list[l]
|
| 447 |
+
HLong_prime = H_combined_long_list[l]
|
| 448 |
+
# print("HShort_prime:", HShort_prime)
|
| 449 |
+
# print("HLong_prime:", HLong_prime)
|
| 450 |
+
# print("HShort_prime shape:", HShort_prime.shape)
|
| 451 |
+
# print("HLong_prime shape:", HLong_prime.shape)
|
| 452 |
+
diff_loss = self.diff_loss(HShort_prime, HLong_prime) + diff_loss
|
| 453 |
+
# print("diff_loss:", diff_loss)
|
| 454 |
+
# print(diff_loss.item())
|
| 455 |
+
# 互交叉注意力机制
|
| 456 |
+
A1 = F.softmax(torch.bmm(torch.matmul(HShort_prime, self.affine1), torch.transpose(HLong_prime, 1, 2)), dim=2)
|
| 457 |
+
A2 = F.softmax(torch.bmm(torch.matmul(HLong_prime, self.affine2), torch.transpose(HShort_prime, 1, 2)), dim=2)
|
| 458 |
+
|
| 459 |
+
HShort_prime_new = torch.bmm(A1, HLong_prime) # 更新的短时特征
|
| 460 |
+
HLong_prime_new = torch.bmm(A2, HShort_prime) # 更新的长时特征
|
| 461 |
+
|
| 462 |
+
HShort_prime_out = self.dropout(HShort_prime_new) if l < self.args.gnn_layers - 1 else HShort_prime_new
|
| 463 |
+
HLong_prime_out = self.dropout(HLong_prime_new) if l <self.args.gnn_layers - 1 else HLong_prime_new
|
| 464 |
+
|
| 465 |
+
H_final.append(HShort_prime_out)
|
| 466 |
+
H_final.append(HLong_prime_out)
|
| 467 |
+
H_final.append(features)
|
| 468 |
+
|
| 469 |
+
H_final = torch.cat([H_final[-3],H_final[-2],H_final[-1]], dim = 2)
|
| 470 |
+
# print("H shape:", H.shape)
|
| 471 |
+
# print("H:", H.shape)
|
| 472 |
+
# print("H_final shape after cat:", H_final.shape)
|
| 473 |
+
H_final = self.attentive_node_features(H_final,lengths,self.nodal_att_type)
|
| 474 |
+
# print("H_final shape after attentive_node_features:", H_final.shape)
|
| 475 |
+
logits = self.out_mlp(H_final)
|
| 476 |
+
# print(diff_loss)
|
| 477 |
+
return logits, self.beta * diff_loss
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
#仅仅使用最后一层的short和long,concat;使用了过去和未来双特征
|
| 481 |
+
#Only the final-layer short and long features are used and concatenated; both past and future features are utilized.
|
| 482 |
+
class DAGERC_new_2(nn.Module):
|
| 483 |
+
|
| 484 |
+
def __init__(self, args, num_class):
|
| 485 |
+
super().__init__()
|
| 486 |
+
self.args = args
|
| 487 |
+
# gcn layer
|
| 488 |
+
|
| 489 |
+
self.dropout = nn.Dropout(args.dropout)
|
| 490 |
+
|
| 491 |
+
self.gnn_layers = args.gnn_layers
|
| 492 |
+
|
| 493 |
+
if not args.no_rel_attn:
|
| 494 |
+
self.rel_attn = True
|
| 495 |
+
else:
|
| 496 |
+
self.rel_attn = False
|
| 497 |
+
|
| 498 |
+
if self.args.attn_type == 'linear':
|
| 499 |
+
gats = []
|
| 500 |
+
for _ in range(args.gnn_layers):
|
| 501 |
+
gats += [GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)]
|
| 502 |
+
self.gather = nn.ModuleList(gats)
|
| 503 |
+
elif self.args.attn_type == 'dotprod':
|
| 504 |
+
gats = []
|
| 505 |
+
for _ in range(args.gnn_layers):
|
| 506 |
+
gats += [GatDot(args.hidden_dim) if args.no_rel_attn else GatDot_rel(args.hidden_dim)]
|
| 507 |
+
self.gather = nn.ModuleList(gats)
|
| 508 |
+
elif self.args.attn_type == 'rgcn':
|
| 509 |
+
#短距离
|
| 510 |
+
gats_short = []
|
| 511 |
+
gats_long = []
|
| 512 |
+
for _ in range(args.gnn_layers):
|
| 513 |
+
gats_short += [GAT_dialoggcn_v1(args.hidden_dim)]
|
| 514 |
+
for _ in range(args.gnn_layers):
|
| 515 |
+
gats_long += [GAT_dialoggcn_v1(args.hidden_dim)]
|
| 516 |
+
self.gather_short = nn.ModuleList(gats_short)
|
| 517 |
+
self.gather_long = nn.ModuleList(gats_long)
|
| 518 |
+
|
| 519 |
+
# 近距离 GRU
|
| 520 |
+
grus_c_short = []
|
| 521 |
+
for _ in range(args.gnn_layers):
|
| 522 |
+
grus_c_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 523 |
+
self.grus_c_short = nn.ModuleList(grus_c_short)
|
| 524 |
+
|
| 525 |
+
# 远距离 GRU
|
| 526 |
+
grus_c_long = []
|
| 527 |
+
for _ in range(args.gnn_layers):
|
| 528 |
+
grus_c_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 529 |
+
self.grus_c_long = nn.ModuleList(grus_c_long)
|
| 530 |
+
|
| 531 |
+
grus_p_short = []
|
| 532 |
+
for _ in range(args.gnn_layers):
|
| 533 |
+
grus_p_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 534 |
+
self.grus_p_short = nn.ModuleList(grus_p_short)
|
| 535 |
+
|
| 536 |
+
grus_p_long = []
|
| 537 |
+
for _ in range(args.gnn_layers):
|
| 538 |
+
grus_p_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 539 |
+
self.grus_p_long = nn.ModuleList(grus_p_long)
|
| 540 |
+
|
| 541 |
+
#近距离全链接层
|
| 542 |
+
fcs_short = []
|
| 543 |
+
for _ in range(args.gnn_layers):
|
| 544 |
+
fcs_short += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
|
| 545 |
+
self.fcs_short = nn.ModuleList(fcs_short)
|
| 546 |
+
|
| 547 |
+
# 远距离全连接层
|
| 548 |
+
fcs_long = []
|
| 549 |
+
for _ in range(args.gnn_layers):
|
| 550 |
+
fcs_long += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
|
| 551 |
+
self.fcs_long = nn.ModuleList(fcs_long)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)
|
| 555 |
+
|
| 556 |
+
self.nodal_att_type = args.nodal_att_type
|
| 557 |
+
|
| 558 |
+
in_dim = ((args.hidden_dim*2)*2 + args.emb_dim)
|
| 559 |
+
# print(in_dim)
|
| 560 |
+
# output mlp layers
|
| 561 |
+
layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
|
| 562 |
+
for _ in range(args.mlp_layers - 1):
|
| 563 |
+
layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
|
| 564 |
+
layers += [self.dropout]
|
| 565 |
+
layers += [nn.Linear(args.hidden_dim, num_class)]
|
| 566 |
+
|
| 567 |
+
self.out_mlp = nn.Sequential(*layers)
|
| 568 |
+
|
| 569 |
+
self.attentive_node_features = attentive_node_features(in_dim)
|
| 570 |
+
|
| 571 |
+
self.affine1 = nn.Parameter(torch.empty(size=((args.hidden_dim*2) , (args.hidden_dim*2) )))
|
| 572 |
+
nn.init.xavier_uniform_(self.affine1.data, gain=1.414)
|
| 573 |
+
self.affine2 = nn.Parameter(torch.empty(size=((args.hidden_dim*2) , (args.hidden_dim*2) )))
|
| 574 |
+
nn.init.xavier_uniform_(self.affine2.data, gain=1.414)
|
| 575 |
+
|
| 576 |
+
self.diff_loss = DiffLoss(args)
|
| 577 |
+
self.beta = args.diffloss
|
| 578 |
+
|
| 579 |
+
def forward(self, features, adj_1, adj_2 ,s_mask, s_mask_onehot, lengths):
|
| 580 |
+
# 检查 H1 和 H2 是否完全相等
|
| 581 |
+
are_equal = all(torch.equal(h1, h2) for h1, h2 in zip(adj_1, adj_2))
|
| 582 |
+
# print("adj1 和 adj2 是否完全相等:", are_equal)
|
| 583 |
+
# print('adj1',adj_1)
|
| 584 |
+
# print('----------------------------------------------------')
|
| 585 |
+
|
| 586 |
+
# print('adj2',adj_2)
|
| 587 |
+
# print('----------------------------------------------------')
|
| 588 |
+
|
| 589 |
+
num_utter = features.size()[1]
|
| 590 |
+
|
| 591 |
+
H0 = F.relu(self.fc1(features))
|
| 592 |
+
#print('H0', H0.size())
|
| 593 |
+
# H0 = self.dropout(H0)
|
| 594 |
+
H = [H0]
|
| 595 |
+
H_combined_short_list = []
|
| 596 |
+
#对短距离特征进行处理
|
| 597 |
+
for l in range(self.args.gnn_layers):
|
| 598 |
+
C = self.grus_c_short[l](H[l][:,0,:]).unsqueeze(1) #针对每一层的第一个节点,使用 GRU 单元更新节点特征并聚合信息。
|
| 599 |
+
M = torch.zeros_like(C).squeeze(1) #初始化一个聚合信息张量 M(全零张量),并使用它与节点特征结合生成额外的特征 P。
|
| 600 |
+
# P = M.unsqueeze(1)
|
| 601 |
+
P = self.grus_p_short[l](M, H[l][:,0,:]).unsqueeze(1) #使用 M(全零张量)和第一个节点的特征 H[l][:, 0, :] 作为输入,得到额外特征 P,形状为 (B, D)
|
| 602 |
+
#H1 = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
|
| 603 |
+
#H1 = F.relu(C+P)
|
| 604 |
+
H1 = C+P#将更新后的特征 C 与额外特征 P 相加,生成新的节点特征 H1,为后续层的计算做准备。
|
| 605 |
+
for i in range(1, num_utter):
|
| 606 |
+
# print(i,num_utter)
|
| 607 |
+
if self.args.attn_type == 'rgcn':
|
| 608 |
+
#将 H[l][:, i, :](当前节点特征),H1(之前节点的特征聚合结果),adj[:, i, :i](当前节点与之前节点的邻接矩阵)
|
| 609 |
+
#s_mask[:, i, :i](当前节点的掩码),得到聚合结果 M
|
| 610 |
+
_, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i], s_mask[:,i,:i])
|
| 611 |
+
# _, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i], s_mask_onehot[:,i,:i,:])
|
| 612 |
+
else:
|
| 613 |
+
if not self.rel_attn:
|
| 614 |
+
_, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i])
|
| 615 |
+
else:
|
| 616 |
+
_, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i], s_mask[:, i, :i])
|
| 617 |
+
|
| 618 |
+
#使用 GRU 单元 self.grus_c[l] 来处理当前节点的特征 H[l][:, i, :] 和聚合后的特征 M,得到新的特征 C。
|
| 619 |
+
# 这表明当前节点的特征更新与其邻居的聚合信息有关。
|
| 620 |
+
C = self.grus_c_short[l](H[l][:,i,:], M).unsqueeze(1)
|
| 621 |
+
#使用另一个 GRU 单元 self.grus_p[l] 来处理聚合特征 M 和当前节点的特征 H[l][:, i, :],得到额外的特征 P。
|
| 622 |
+
P = self.grus_p_short[l](M, H[l][:,i,:]).unsqueeze(1)
|
| 623 |
+
# P = M.unsqueeze(1)
|
| 624 |
+
#H_temp = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
|
| 625 |
+
#H_temp = F.relu(C+P)
|
| 626 |
+
H_temp = C+P#将更新后的特征 C 和额外特征 P 进行相加,生成新的节点特征 H_temp
|
| 627 |
+
H1 = torch.cat((H1 , H_temp), dim = 1) #将当前节点的特征 H_temp 拼接到 H1 中。
|
| 628 |
+
# print('H1', H1.size())
|
| 629 |
+
#print('----------------------------------------------------')
|
| 630 |
+
H.append(H1)
|
| 631 |
+
|
| 632 |
+
# 将输入特征反转
|
| 633 |
+
# 反向特征提取
|
| 634 |
+
features_reversed = torch.flip(features, dims=[1]) # 反转特征顺序
|
| 635 |
+
adj_reversed = torch.flip(adj_1, dims=[1, 2]) # 反转邻接矩阵
|
| 636 |
+
s_mask_reversed = torch.flip(s_mask, dims=[1, 2]) # 反转掩码
|
| 637 |
+
|
| 638 |
+
H0_reversed = F.relu(self.fc1(features_reversed))
|
| 639 |
+
H_reversed = [H0_reversed]
|
| 640 |
+
|
| 641 |
+
for l in range(self.args.gnn_layers):
|
| 642 |
+
C = self.grus_c_short[l](H_reversed[l][:, 0, :]).unsqueeze(1)
|
| 643 |
+
M = torch.zeros_like(C).squeeze(1)
|
| 644 |
+
P = self.grus_p_short[l](M, H_reversed[l][:, 0, :]).unsqueeze(1)
|
| 645 |
+
H1_reversed = C + P
|
| 646 |
+
|
| 647 |
+
for i in range(1, num_utter):
|
| 648 |
+
if self.args.attn_type == 'rgcn':
|
| 649 |
+
_, M = self.gather_short[l](H_reversed[l][:, i, :], H1_reversed, H1_reversed, adj_reversed[:, i, :i], s_mask_reversed[:, i, :i])
|
| 650 |
+
else:
|
| 651 |
+
if not self.rel_attn:
|
| 652 |
+
_, M = self.gather_short[l](H_reversed[l][:, i, :], H1_reversed, H1_reversed, adj_reversed[:, i, :i])
|
| 653 |
+
else:
|
| 654 |
+
_, M = self.gather_short[l](H_reversed[l][:, i, :], H1_reversed, H1_reversed, adj_reversed[:, i, :i], s_mask_reversed[:, i, :i])
|
| 655 |
+
|
| 656 |
+
C = self.grus_c_short[l](H_reversed[l][:, i, :], M).unsqueeze(1)
|
| 657 |
+
P = self.grus_p_short[l](M, H_reversed[l][:, i, :]).unsqueeze(1)
|
| 658 |
+
H_temp_reversed = C + P
|
| 659 |
+
H1_reversed = torch.cat((H1_reversed, H_temp_reversed), dim=1)
|
| 660 |
+
H_reversed.append(H1_reversed)
|
| 661 |
+
H_combined = torch.cat((H[l+1], H_reversed[l+1]), dim=2) # 在第二维度拼接
|
| 662 |
+
H_combined_short_list.append(H_combined) # 将拼接后的结果添加到新列表中
|
| 663 |
+
|
| 664 |
+
'''
|
| 665 |
+
下面对长距离特征进行处理 The following processes the long-distance features.
|
| 666 |
+
'''
|
| 667 |
+
H_long = [H0] # 初始化 H_long
|
| 668 |
+
H_combined_long_list = [] # 存储长距离处理的结果
|
| 669 |
+
|
| 670 |
+
# 对长距离特征进行处理
|
| 671 |
+
for l in range(self.args.gnn_layers):
|
| 672 |
+
C_long = self.grus_c_long[l](H_long[l][:,0,:]).unsqueeze(1) # 使用 GRU 更新长距离的第一个��点
|
| 673 |
+
M_long = torch.zeros_like(C_long).squeeze(1) # 初始化长距离的聚合信息张量 M_long
|
| 674 |
+
P_long = self.grus_p_long[l](M_long, H_long[l][:,0,:]).unsqueeze(1) # 生成额外的特征 P_long
|
| 675 |
+
|
| 676 |
+
H1_long = C_long + P_long # 生成新的长距离节点特征 H1_long
|
| 677 |
+
for i in range(1, num_utter):
|
| 678 |
+
# 依据不同的 attention 类型,进行特征聚合
|
| 679 |
+
if self.args.attn_type == 'rgcn':
|
| 680 |
+
_, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i], s_mask[:,i,:i])
|
| 681 |
+
else:
|
| 682 |
+
if not self.rel_attn:
|
| 683 |
+
_, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i])
|
| 684 |
+
else:
|
| 685 |
+
_, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i], s_mask[:,i,:i])
|
| 686 |
+
|
| 687 |
+
# 使用 GRU 更新当前节点的特征 C_long 和 M_long
|
| 688 |
+
C_long = self.grus_c_long[l](H_long[l][:,i,:], M_long).unsqueeze(1)
|
| 689 |
+
P_long = self.grus_p_long[l](M_long, H_long[l][:,i,:]).unsqueeze(1)
|
| 690 |
+
|
| 691 |
+
H_temp_long = C_long + P_long # 将更新后的特征 C_long 和 P_long 相加生成新特征
|
| 692 |
+
H1_long = torch.cat((H1_long, H_temp_long), dim=1) # 将特征拼接到 H1_long 中
|
| 693 |
+
H_long.append(H1_long) # 更新 H_long 列表
|
| 694 |
+
|
| 695 |
+
# 反转特征顺序,进行逆向长距离特征提取
|
| 696 |
+
features_reversed_long = torch.flip(features, dims=[1]) # 反转特征顺序
|
| 697 |
+
adj_reversed_long = torch.flip(adj_2, dims=[1, 2]) # 反转长距离邻接矩阵
|
| 698 |
+
s_mask_reversed_long = torch.flip(s_mask, dims=[1, 2]) # 反转掩码
|
| 699 |
+
|
| 700 |
+
H0_reversed_long = F.relu(self.fc1(features_reversed_long))
|
| 701 |
+
H_reversed_long = [H0_reversed_long]
|
| 702 |
+
|
| 703 |
+
for l in range(self.args.gnn_layers):
|
| 704 |
+
C_long = self.grus_c_long[l](H_reversed_long[l][:, 0, :]).unsqueeze(1)
|
| 705 |
+
M_long = torch.zeros_like(C_long).squeeze(1)
|
| 706 |
+
P_long = self.grus_p_long[l](M_long, H_reversed_long[l][:, 0, :]).unsqueeze(1)
|
| 707 |
+
H1_reversed_long = C_long + P_long
|
| 708 |
+
|
| 709 |
+
for i in range(1, num_utter):
|
| 710 |
+
if self.args.attn_type == 'rgcn':
|
| 711 |
+
_, M_long = self.gather_long[l](H_reversed_long[l][:, i, :], H1_reversed_long, H1_reversed_long, adj_reversed_long[:, i, :i], s_mask_reversed_long[:, i, :i])
|
| 712 |
+
else:
|
| 713 |
+
if not self.rel_attn:
|
| 714 |
+
_, M_long = self.gather_long[l](H_reversed_long[l][:, i, :], H1_reversed_long, H1_reversed_long, adj_reversed_long[:, i, :i])
|
| 715 |
+
else:
|
| 716 |
+
_, M_long = self.gather_long[l](H_reversed_long[l][:, i, :], H1_reversed_long, H1_reversed_long, adj_reversed_long[:, i, :i], s_mask_reversed_long[:, i, :i])
|
| 717 |
+
|
| 718 |
+
C_long = self.grus_c_long[l](H_reversed_long[l][:, i, :], M_long).unsqueeze(1)
|
| 719 |
+
P_long = self.grus_p_long[l](M_long, H_reversed_long[l][:, i, :]).unsqueeze(1)
|
| 720 |
+
H_temp_reversed_long = C_long + P_long
|
| 721 |
+
H1_reversed_long = torch.cat((H1_reversed_long, H_temp_reversed_long), dim=1)
|
| 722 |
+
H_reversed_long.append(H1_reversed_long)
|
| 723 |
+
|
| 724 |
+
# 将正向和逆向的长距离特征进行拼接
|
| 725 |
+
H_combined_long = torch.cat((H_long[l+1], H_reversed_long[l+1]), dim=2)
|
| 726 |
+
H_combined_long_list.append(H_combined_long)
|
| 727 |
+
|
| 728 |
+
'''
|
| 729 |
+
两个通道特征都提取完毕! Both short- and long-distance channel features have been extracted!
|
| 730 |
+
'''
|
| 731 |
+
# print('H_combined_short_list',H_combined_short_list)
|
| 732 |
+
# print('H_combined_long_list',H_combined_long_list)
|
| 733 |
+
# are_equal = all(torch.equal(h1, h2) for h1, h2 in zip(H_combined_short_list, H_combined_long_list))
|
| 734 |
+
# print("H_combined_short_list 和 H_combined_long_list 是否完全相等:", are_equal)
|
| 735 |
+
# for idx, tensor in enumerate(H_combined_short_list):
|
| 736 |
+
# print(f"H_combined_short_list[{idx}] shape: {tensor.shape}")
|
| 737 |
+
H_final = []
|
| 738 |
+
# print("H2 shape:", H2.shape)
|
| 739 |
+
# 计算差异正则化损失
|
| 740 |
+
diff_loss = 0
|
| 741 |
+
for l in range(self.args.gnn_layers):
|
| 742 |
+
# print('周期:', l)
|
| 743 |
+
HShort_prime = H_combined_short_list[l]
|
| 744 |
+
HLong_prime = H_combined_long_list[l]
|
| 745 |
+
print("HShort_prime:", HShort_prime)
|
| 746 |
+
print("HLong_prime:", HLong_prime)
|
| 747 |
+
print("HShort_prime shape:", HShort_prime.shape)
|
| 748 |
+
print("HLong_prime shape:", HLong_prime.shape)
|
| 749 |
+
diff_loss = self.diff_loss(HShort_prime, HLong_prime) + diff_loss
|
| 750 |
+
# print("diff_loss:", diff_loss)
|
| 751 |
+
# print(diff_loss.item())
|
| 752 |
+
# 互交叉注意力机制
|
| 753 |
+
A1 = F.softmax(torch.bmm(torch.matmul(HShort_prime, self.affine1), torch.transpose(HLong_prime, 1, 2)), dim=2)
|
| 754 |
+
A2 = F.softmax(torch.bmm(torch.matmul(HLong_prime, self.affine2), torch.transpose(HShort_prime, 1, 2)), dim=2)
|
| 755 |
+
|
| 756 |
+
HShort_prime_new = torch.bmm(A1, HLong_prime) # 更新的短时特征
|
| 757 |
+
HLong_prime_new = torch.bmm(A2, HShort_prime) # 更新的长时特征
|
| 758 |
+
|
| 759 |
+
HShort_prime_out = self.dropout(HShort_prime_new) if l < self.args.gnn_layers - 1 else HShort_prime_new
|
| 760 |
+
HLong_prime_out = self.dropout(HLong_prime_new) if l <self.args.gnn_layers - 1 else HLong_prime_new
|
| 761 |
+
|
| 762 |
+
H_final.append(HShort_prime_out)
|
| 763 |
+
H_final.append(HLong_prime_out)
|
| 764 |
+
H_final.append(features)
|
| 765 |
+
|
| 766 |
+
H_final = torch.cat([H_final[-3],H_final[-2],H_final[-1]], dim = 2)
|
| 767 |
+
# print("H shape:", H.shape)
|
| 768 |
+
# print("H:", H.shape)
|
| 769 |
+
# print("H_final shape after cat:", H_final.shape)
|
| 770 |
+
H_final = self.attentive_node_features(H_final,lengths,self.nodal_att_type)
|
| 771 |
+
# print("H_final shape after attentive_node_features:", H_final.shape)
|
| 772 |
+
logits = self.out_mlp(H_final)
|
| 773 |
+
# print(diff_loss)
|
| 774 |
+
return logits, self.beta * diff_loss
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
#使用所有层的short和long,使用sum加每一层,不使用双特征融合技术
|
| 778 |
+
#All-layer short and long features are used, with a sum over each layer; dual-feature fusion is not applied.
|
| 779 |
+
class DAGERC_new_3(nn.Module):
|
| 780 |
+
|
| 781 |
+
def __init__(self, args, num_class):
|
| 782 |
+
super().__init__()
|
| 783 |
+
self.args = args
|
| 784 |
+
# gcn layer
|
| 785 |
+
|
| 786 |
+
self.dropout = nn.Dropout(args.dropout)
|
| 787 |
+
|
| 788 |
+
self.gnn_layers = args.gnn_layers
|
| 789 |
+
|
| 790 |
+
if not args.no_rel_attn:
|
| 791 |
+
self.rel_attn = True
|
| 792 |
+
else:
|
| 793 |
+
self.rel_attn = False
|
| 794 |
+
|
| 795 |
+
if self.args.attn_type == 'linear':
|
| 796 |
+
gats = []
|
| 797 |
+
for _ in range(args.gnn_layers):
|
| 798 |
+
gats += [GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)]
|
| 799 |
+
self.gather = nn.ModuleList(gats)
|
| 800 |
+
elif self.args.attn_type == 'dotprod':
|
| 801 |
+
gats = []
|
| 802 |
+
for _ in range(args.gnn_layers):
|
| 803 |
+
gats += [GatDot(args.hidden_dim) if args.no_rel_attn else GatDot_rel(args.hidden_dim)]
|
| 804 |
+
self.gather = nn.ModuleList(gats)
|
| 805 |
+
elif self.args.attn_type == 'rgcn':
|
| 806 |
+
#短距离
|
| 807 |
+
gats_short = []
|
| 808 |
+
gats_long = []
|
| 809 |
+
for _ in range(args.gnn_layers):
|
| 810 |
+
gats_short += [GAT_dialoggcn_v1(args.hidden_dim)]
|
| 811 |
+
for _ in range(args.gnn_layers):
|
| 812 |
+
gats_long += [GAT_dialoggcn_v1(args.hidden_dim)]
|
| 813 |
+
self.gather_short = nn.ModuleList(gats_short)
|
| 814 |
+
self.gather_long = nn.ModuleList(gats_long)
|
| 815 |
+
|
| 816 |
+
# 近距离 GRU
|
| 817 |
+
grus_c_short = []
|
| 818 |
+
for _ in range(args.gnn_layers):
|
| 819 |
+
grus_c_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 820 |
+
self.grus_c_short = nn.ModuleList(grus_c_short)
|
| 821 |
+
|
| 822 |
+
# 远距离 GRU
|
| 823 |
+
grus_c_long = []
|
| 824 |
+
for _ in range(args.gnn_layers):
|
| 825 |
+
grus_c_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 826 |
+
self.grus_c_long = nn.ModuleList(grus_c_long)
|
| 827 |
+
|
| 828 |
+
grus_p_short = []
|
| 829 |
+
for _ in range(args.gnn_layers):
|
| 830 |
+
grus_p_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 831 |
+
self.grus_p_short = nn.ModuleList(grus_p_short)
|
| 832 |
+
|
| 833 |
+
grus_p_long = []
|
| 834 |
+
for _ in range(args.gnn_layers):
|
| 835 |
+
grus_p_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 836 |
+
self.grus_p_long = nn.ModuleList(grus_p_long)
|
| 837 |
+
|
| 838 |
+
#近距离全链接层
|
| 839 |
+
fcs_short = []
|
| 840 |
+
for _ in range(args.gnn_layers):
|
| 841 |
+
fcs_short += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
|
| 842 |
+
self.fcs_short = nn.ModuleList(fcs_short)
|
| 843 |
+
|
| 844 |
+
# 远距离全连接层
|
| 845 |
+
fcs_long = []
|
| 846 |
+
for _ in range(args.gnn_layers):
|
| 847 |
+
fcs_long += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
|
| 848 |
+
self.fcs_long = nn.ModuleList(fcs_long)
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)
|
| 852 |
+
|
| 853 |
+
self.nodal_att_type = args.nodal_att_type
|
| 854 |
+
|
| 855 |
+
in_dim = (args.hidden_dim * (args.gnn_layers + 1)) + args.emb_dim
|
| 856 |
+
# print(in_dim)
|
| 857 |
+
# output mlp layers
|
| 858 |
+
layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
|
| 859 |
+
for _ in range(args.mlp_layers - 1):
|
| 860 |
+
layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
|
| 861 |
+
layers += [self.dropout]
|
| 862 |
+
layers += [nn.Linear(args.hidden_dim, num_class)]
|
| 863 |
+
|
| 864 |
+
self.out_mlp = nn.Sequential(*layers)
|
| 865 |
+
|
| 866 |
+
self.attentive_node_features = attentive_node_features(in_dim)
|
| 867 |
+
|
| 868 |
+
def forward(self, features, adj_1, adj_2 ,s_mask, s_mask_onehot, lengths):
|
| 869 |
+
# 检查 H1 和 H2 是否完全相等
|
| 870 |
+
are_equal = all(torch.equal(h1, h2) for h1, h2 in zip(adj_1, adj_2))
|
| 871 |
+
# print("adj1 和 adj2 是否完全相等:", are_equal)
|
| 872 |
+
# print('adj1',adj_1)
|
| 873 |
+
# print('----------------------------------------------------')
|
| 874 |
+
|
| 875 |
+
# print('adj2',adj_2)
|
| 876 |
+
# print('----------------------------------------------------')
|
| 877 |
+
|
| 878 |
+
num_utter = features.size()[1]
|
| 879 |
+
|
| 880 |
+
H0 = F.relu(self.fc1(features))
|
| 881 |
+
#print('H0', H0.size())
|
| 882 |
+
# H0 = self.dropout(H0)
|
| 883 |
+
H = [H0]
|
| 884 |
+
H_combined_short_list = []
|
| 885 |
+
#对短距离特征进行处理
|
| 886 |
+
for l in range(self.args.gnn_layers):
|
| 887 |
+
C = self.grus_c_short[l](H[l][:,0,:]).unsqueeze(1) #针对每一层的第一个节点,使用 GRU 单元更新节点特征并聚合信息。
|
| 888 |
+
M = torch.zeros_like(C).squeeze(1) #初始化一个聚合信息张量 M(全零张量),并使用它与节点特征结合生成额外的特征 P。
|
| 889 |
+
# P = M.unsqueeze(1)
|
| 890 |
+
P = self.grus_p_short[l](M, H[l][:,0,:]).unsqueeze(1) #使用 M(全零张量)和第一个节点的特征 H[l][:, 0, :] 作为输入,得到额外特征 P,形状为 (B, D)
|
| 891 |
+
#H1 = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
|
| 892 |
+
#H1 = F.relu(C+P)
|
| 893 |
+
H1 = C+P#将更新后的特征 C 与额外特征 P 相加,生成新的节点特征 H1,为后续层的计算做准备。
|
| 894 |
+
for i in range(1, num_utter):
|
| 895 |
+
# print(i,num_utter)
|
| 896 |
+
if self.args.attn_type == 'rgcn':
|
| 897 |
+
#将 H[l][:, i, :](当前节点特征),H1(之前节点的特征聚合结果),adj[:, i, :i](当前节点与之前节点的邻接矩阵)
|
| 898 |
+
#s_mask[:, i, :i](当前节点的掩码),得到聚合结果 M
|
| 899 |
+
_, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i], s_mask[:,i,:i])
|
| 900 |
+
# _, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i], s_mask_onehot[:,i,:i,:])
|
| 901 |
+
else:
|
| 902 |
+
if not self.rel_attn:
|
| 903 |
+
_, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i])
|
| 904 |
+
else:
|
| 905 |
+
_, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i], s_mask[:, i, :i])
|
| 906 |
+
|
| 907 |
+
#使用 GRU 单元 self.grus_c[l] 来处理当前节点的特征 H[l][:, i, :] 和聚合后的特征 M,得到新的特征 C。
|
| 908 |
+
# 这表明当前节点的特征更新与其邻居的聚合信息有关。
|
| 909 |
+
C = self.grus_c_short[l](H[l][:,i,:], M).unsqueeze(1)
|
| 910 |
+
#使用另一个 GRU 单元 self.grus_p[l] 来处理聚合特征 M 和当前节点的特征 H[l][:, i, :],得到额外的特征 P。
|
| 911 |
+
P = self.grus_p_short[l](M, H[l][:,i,:]).unsqueeze(1)
|
| 912 |
+
# P = M.unsqueeze(1)
|
| 913 |
+
#H_temp = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
|
| 914 |
+
#H_temp = F.relu(C+P)
|
| 915 |
+
H_temp = C+P#将更新后的特征 C 和额外特征 P 进行相加,生成新的节点特征 H_temp
|
| 916 |
+
H1 = torch.cat((H1 , H_temp), dim = 1) #将当前节点的特征 H_temp 拼接到 H1 中。
|
| 917 |
+
# print('H1', H1.size())
|
| 918 |
+
#print('----------------------------------------------------')
|
| 919 |
+
H.append(H1)
|
| 920 |
+
'''
|
| 921 |
+
下面对长距离特征进行处理
|
| 922 |
+
'''
|
| 923 |
+
H_long = [H0] # 初始化 H_long
|
| 924 |
+
H_combined_long_list = [] # 存储长距离处理的结果
|
| 925 |
+
|
| 926 |
+
# 对长距离特征进行处理
|
| 927 |
+
for l in range(self.args.gnn_layers):
|
| 928 |
+
C_long = self.grus_c_long[l](H_long[l][:,0,:]).unsqueeze(1) # 使用 GRU 更新长距离的第一个节点
|
| 929 |
+
M_long = torch.zeros_like(C_long).squeeze(1) # 初始化长距离的聚合信息张量 M_long
|
| 930 |
+
P_long = self.grus_p_long[l](M_long, H_long[l][:,0,:]).unsqueeze(1) # 生成额外的特征 P_long
|
| 931 |
+
|
| 932 |
+
H1_long = C_long + P_long # 生成新的长距离节点特征 H1_long
|
| 933 |
+
for i in range(1, num_utter):
|
| 934 |
+
# 依据不同的 attention 类型,进行特征聚合
|
| 935 |
+
if self.args.attn_type == 'rgcn':
|
| 936 |
+
_, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i], s_mask[:,i,:i])
|
| 937 |
+
else:
|
| 938 |
+
if not self.rel_attn:
|
| 939 |
+
_, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i])
|
| 940 |
+
else:
|
| 941 |
+
_, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i], s_mask[:,i,:i])
|
| 942 |
+
|
| 943 |
+
# 使用 GRU 更新当前节点的特征 C_long 和 M_long
|
| 944 |
+
C_long = self.grus_c_long[l](H_long[l][:,i,:], M_long).unsqueeze(1)
|
| 945 |
+
P_long = self.grus_p_long[l](M_long, H_long[l][:,i,:]).unsqueeze(1)
|
| 946 |
+
|
| 947 |
+
H_temp_long = C_long + P_long # 将更新后的特征 C_long 和 P_long 相加生成新特征
|
| 948 |
+
H1_long = torch.cat((H1_long, H_temp_long), dim=1) # 将特征拼接到 H1_long 中
|
| 949 |
+
H_long.append(H1_long) # 更新 H_long 列表
|
| 950 |
+
# for i, h in enumerate(H):
|
| 951 |
+
# print(f"H[{i}] shape: {h.shape}")
|
| 952 |
+
|
| 953 |
+
H_combined = torch.cat(H, dim=2)
|
| 954 |
+
H_long_combined = torch.cat(H_long, dim=2)
|
| 955 |
+
sum_features = H_combined + H_long_combined
|
| 956 |
+
# print('sum_features Shape:', sum_features.shape)
|
| 957 |
+
# print('features Shape:', features.shape)
|
| 958 |
+
H_combined_final = torch.cat((sum_features, features), dim=2)
|
| 959 |
+
|
| 960 |
+
H_final = self.attentive_node_features(H_combined_final,lengths,self.nodal_att_type)
|
| 961 |
+
# print("H_final shape after attentive_node_features:", H_final.shape)
|
| 962 |
+
logits = self.out_mlp(H_final)
|
| 963 |
+
# print(diff_loss)
|
| 964 |
+
return logits
|
| 965 |
+
|
| 966 |
+
#使用过去的所有层的short和long,每一层都concat,使用特征融合技术。
|
| 967 |
+
#All past-layer short and long features are used; features from each layer are concatenated, and feature fusion techniques are applied.
|
| 968 |
+
class DAGERC_new_4(nn.Module):
|
| 969 |
+
|
| 970 |
+
def __init__(self, args, num_class):
|
| 971 |
+
super().__init__()
|
| 972 |
+
self.args = args
|
| 973 |
+
# gcn layer
|
| 974 |
+
|
| 975 |
+
self.dropout = nn.Dropout(args.dropout)
|
| 976 |
+
|
| 977 |
+
self.gnn_layers = args.gnn_layers
|
| 978 |
+
|
| 979 |
+
if not args.no_rel_attn:
|
| 980 |
+
self.rel_attn = True
|
| 981 |
+
else:
|
| 982 |
+
self.rel_attn = False
|
| 983 |
+
|
| 984 |
+
if self.args.attn_type == 'linear':
|
| 985 |
+
gats = []
|
| 986 |
+
for _ in range(args.gnn_layers):
|
| 987 |
+
gats += [GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)]
|
| 988 |
+
self.gather = nn.ModuleList(gats)
|
| 989 |
+
elif self.args.attn_type == 'dotprod':
|
| 990 |
+
gats = []
|
| 991 |
+
for _ in range(args.gnn_layers):
|
| 992 |
+
gats += [GatDot(args.hidden_dim) if args.no_rel_attn else GatDot_rel(args.hidden_dim)]
|
| 993 |
+
self.gather = nn.ModuleList(gats)
|
| 994 |
+
elif self.args.attn_type == 'rgcn':
|
| 995 |
+
gats_short = []
|
| 996 |
+
gats_long = []
|
| 997 |
+
for _ in range(args.gnn_layers):
|
| 998 |
+
gats_short += [GAT_dialoggcn_v1(args.hidden_dim)]
|
| 999 |
+
for _ in range(args.gnn_layers):
|
| 1000 |
+
gats_long += [GAT_dialoggcn_v1(args.hidden_dim)]
|
| 1001 |
+
self.gather_short = nn.ModuleList(gats_short)
|
| 1002 |
+
self.gather_long = nn.ModuleList(gats_long)
|
| 1003 |
+
|
| 1004 |
+
# 近距离 GRU
|
| 1005 |
+
# short distance GRU
|
| 1006 |
+
grus_c_short = []
|
| 1007 |
+
for _ in range(args.gnn_layers):
|
| 1008 |
+
grus_c_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 1009 |
+
self.grus_c_short = nn.ModuleList(grus_c_short)
|
| 1010 |
+
|
| 1011 |
+
# 远距离 GRU
|
| 1012 |
+
# long distance GRU
|
| 1013 |
+
grus_c_long = []
|
| 1014 |
+
for _ in range(args.gnn_layers):
|
| 1015 |
+
grus_c_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 1016 |
+
self.grus_c_long = nn.ModuleList(grus_c_long)
|
| 1017 |
+
|
| 1018 |
+
grus_p_short = []
|
| 1019 |
+
for _ in range(args.gnn_layers):
|
| 1020 |
+
grus_p_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 1021 |
+
self.grus_p_short = nn.ModuleList(grus_p_short)
|
| 1022 |
+
|
| 1023 |
+
grus_p_long = []
|
| 1024 |
+
for _ in range(args.gnn_layers):
|
| 1025 |
+
grus_p_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
|
| 1026 |
+
self.grus_p_long = nn.ModuleList(grus_p_long)
|
| 1027 |
+
|
| 1028 |
+
#近距离全链接层
|
| 1029 |
+
#Fully Connected Layer for Short-Range Features
|
| 1030 |
+
fcs_short = []
|
| 1031 |
+
for _ in range(args.gnn_layers):
|
| 1032 |
+
fcs_short += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
|
| 1033 |
+
self.fcs_short = nn.ModuleList(fcs_short)
|
| 1034 |
+
|
| 1035 |
+
# 远距离全连接层
|
| 1036 |
+
# Fully Connected Layer for Long-Range Features
|
| 1037 |
+
fcs_long = []
|
| 1038 |
+
for _ in range(args.gnn_layers):
|
| 1039 |
+
fcs_long += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
|
| 1040 |
+
self.fcs_long = nn.ModuleList(fcs_long)
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)
|
| 1044 |
+
|
| 1045 |
+
self.nodal_att_type = args.nodal_att_type
|
| 1046 |
+
|
| 1047 |
+
in_dim = (((args.hidden_dim*2))*(args.gnn_layers + 1) + args.emb_dim)
|
| 1048 |
+
|
| 1049 |
+
# output mlp layers
|
| 1050 |
+
layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
|
| 1051 |
+
for _ in range(args.mlp_layers - 1):
|
| 1052 |
+
layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
|
| 1053 |
+
layers += [self.dropout]
|
| 1054 |
+
layers += [nn.Linear(args.hidden_dim, num_class)]
|
| 1055 |
+
|
| 1056 |
+
self.out_mlp = nn.Sequential(*layers)
|
| 1057 |
+
|
| 1058 |
+
self.attentive_node_features = attentive_node_features(in_dim)
|
| 1059 |
+
|
| 1060 |
+
self.affine1 = nn.Parameter(torch.empty(size=((args.hidden_dim) , (args.hidden_dim) )))
|
| 1061 |
+
nn.init.xavier_uniform_(self.affine1.data, gain=1.414)
|
| 1062 |
+
self.affine2 = nn.Parameter(torch.empty(size=((args.hidden_dim) , (args.hidden_dim) )))
|
| 1063 |
+
nn.init.xavier_uniform_(self.affine2.data, gain=1.414)
|
| 1064 |
+
|
| 1065 |
+
self.diff_loss = DiffLoss(args)
|
| 1066 |
+
self.beta = args.diffloss
|
| 1067 |
+
|
| 1068 |
+
def forward(self, features, adj_1, adj_2 ,s_mask,s_mask_onehot, lengths):
|
| 1069 |
+
# 检查 H1 和 H2 是否完全相等
|
| 1070 |
+
are_equal = all(torch.equal(h1, h2) for h1, h2 in zip(adj_1, adj_2))
|
| 1071 |
+
# print("adj1 和 adj2 是否完全相等:", are_equal)
|
| 1072 |
+
# print('adj1',adj_1)
|
| 1073 |
+
# print('----------------------------------------------------')
|
| 1074 |
+
|
| 1075 |
+
# print('adj2',adj_2)
|
| 1076 |
+
# print('----------------------------------------------------')
|
| 1077 |
+
|
| 1078 |
+
num_utter = features.size()[1]
|
| 1079 |
+
|
| 1080 |
+
H0 = F.relu(self.fc1(features))
|
| 1081 |
+
#print('H0', H0.size())
|
| 1082 |
+
# H0 = self.dropout(H0)
|
| 1083 |
+
H = [H0]
|
| 1084 |
+
H_combined_short_list = []
|
| 1085 |
+
#对短距离特征进行处理 Process short-range features.
|
| 1086 |
+
for l in range(self.args.gnn_layers):
|
| 1087 |
+
C = self.grus_c_short[l](H[l][:,0,:]).unsqueeze(1) #针对每一层的第一个节点,使用 GRU 单元更新节点特征并聚合信息。For the first node of each layer, use a GRU unit to update the node features and aggregate information.
|
| 1088 |
+
M = torch.zeros_like(C).squeeze(1) #初始化一个聚合信息张量 M(全零张量),并使用它与节点特征结合生成额外的特征 P。Initialize an aggregation tensor M (a zero tensor), and use it together with the node features to generate additional features P.
|
| 1089 |
+
# P = M.unsqueeze(1)
|
| 1090 |
+
P = self.grus_p_short[l](M, H[l][:,0,:]).unsqueeze(1)
|
| 1091 |
+
#H1 = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
|
| 1092 |
+
#H1 = F.relu(C+P)
|
| 1093 |
+
H1 = C+P
|
| 1094 |
+
for i in range(1, num_utter):
|
| 1095 |
+
# print(i,num_utter)
|
| 1096 |
+
if self.args.attn_type == 'rgcn':
|
| 1097 |
+
_, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i], s_mask[:,i,:i])
|
| 1098 |
+
# _, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i], s_mask_onehot[:,i,:i,:])
|
| 1099 |
+
else:
|
| 1100 |
+
if not self.rel_attn:
|
| 1101 |
+
_, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i])
|
| 1102 |
+
else:
|
| 1103 |
+
_, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i], s_mask[:, i, :i])
|
| 1104 |
+
|
| 1105 |
+
|
| 1106 |
+
C = self.grus_c_short[l](H[l][:,i,:], M).unsqueeze(1)
|
| 1107 |
+
|
| 1108 |
+
P = self.grus_p_short[l](M, H[l][:,i,:]).unsqueeze(1)
|
| 1109 |
+
# P = M.unsqueeze(1)
|
| 1110 |
+
#H_temp = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
|
| 1111 |
+
#H_temp = F.relu(C+P)
|
| 1112 |
+
H_temp = C+P#将更新后的特征 C 和额外特征 P 进行相加,生成新的节点特征 H_temp
|
| 1113 |
+
H1 = torch.cat((H1 , H_temp), dim = 1) #将当前节点的特征 H_temp 拼接到 H1 中。
|
| 1114 |
+
# print('H1', H1.size())
|
| 1115 |
+
#print('----------------------------------------------------')
|
| 1116 |
+
H.append(H1)
|
| 1117 |
+
H_combined_short_list.append(H[l+1])
|
| 1118 |
+
|
| 1119 |
+
'''
|
| 1120 |
+
下面对长距离特征进行处理 The following processes the long-distance features.
|
| 1121 |
+
'''
|
| 1122 |
+
H_long = [H0] # 初始化 H_long
|
| 1123 |
+
H_combined_long_list = [] # 存储长距离处理的结果
|
| 1124 |
+
|
| 1125 |
+
# 对长距离特征进行处理
|
| 1126 |
+
for l in range(self.args.gnn_layers):
|
| 1127 |
+
C_long = self.grus_c_long[l](H_long[l][:,0,:]).unsqueeze(1) # 使用 GRU 更新长距离的第一个节点
|
| 1128 |
+
M_long = torch.zeros_like(C_long).squeeze(1) # 初始化长距离的聚合信息张量 M_long
|
| 1129 |
+
P_long = self.grus_p_long[l](M_long, H_long[l][:,0,:]).unsqueeze(1) # 生成额外的特征 P_long
|
| 1130 |
+
|
| 1131 |
+
H1_long = C_long + P_long # 生成新的长距离节点特征 H1_long
|
| 1132 |
+
for i in range(1, num_utter):
|
| 1133 |
+
# 依据不同的 attention 类型,进行特征聚合
|
| 1134 |
+
if self.args.attn_type == 'rgcn':
|
| 1135 |
+
_, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i], s_mask[:,i,:i])
|
| 1136 |
+
else:
|
| 1137 |
+
if not self.rel_attn:
|
| 1138 |
+
_, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i])
|
| 1139 |
+
else:
|
| 1140 |
+
_, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i], s_mask[:,i,:i])
|
| 1141 |
+
|
| 1142 |
+
# 使用 GRU 更新当前节点的特征 C_long 和 M_long
|
| 1143 |
+
C_long = self.grus_c_long[l](H_long[l][:,i,:], M_long).unsqueeze(1)
|
| 1144 |
+
P_long = self.grus_p_long[l](M_long, H_long[l][:,i,:]).unsqueeze(1)
|
| 1145 |
+
|
| 1146 |
+
H_temp_long = C_long + P_long # 将更新后的特征 C_long 和 P_long 相加生成新特征
|
| 1147 |
+
H1_long = torch.cat((H1_long, H_temp_long), dim=1) # 将特征拼接到 H1_long 中
|
| 1148 |
+
H_long.append(H1_long) # 更新 H_long 列表
|
| 1149 |
+
H_combined_long_list.append(H_long[l+1])
|
| 1150 |
+
'''
|
| 1151 |
+
两个通道特征都提取完毕!Both short- and long-distance channel features have been extracted!
|
| 1152 |
+
'''
|
| 1153 |
+
# print('H_combined_short_list',H_combined_short_list)
|
| 1154 |
+
# print('H_combined_long_list',H_combined_long_list)
|
| 1155 |
+
# are_equal = all(torch.equal(h1, h2) for h1, h2 in zip(H_combined_short_list, H_combined_long_list))
|
| 1156 |
+
# print("H_combined_short_list 和 H_combined_long_list 是否完全相等:", are_equal)
|
| 1157 |
+
# for idx, tensor in enumerate(H_combined_short_list):
|
| 1158 |
+
# print(f"H_combined_short_list[{idx}] shape: {tensor.shape}")
|
| 1159 |
+
H_final = []
|
| 1160 |
+
H_0_final = torch.cat([H0, H0], dim=2)
|
| 1161 |
+
H_final.append(H_0_final)
|
| 1162 |
+
# print("H2 shape:", H2.shape)
|
| 1163 |
+
# 计算差异正则化损失
|
| 1164 |
+
diff_loss = 0
|
| 1165 |
+
for l in range(self.args.gnn_layers):
|
| 1166 |
+
# print('周期:', l)
|
| 1167 |
+
HShort_prime = H_combined_short_list[l]
|
| 1168 |
+
HLong_prime = H_combined_long_list[l]
|
| 1169 |
+
#print("HShort_prime:", HShort_prime.shape)
|
| 1170 |
+
# print("HLong_prime:", HLong_prime.shape)
|
| 1171 |
+
# print("HShort_prime shape:", HShort_prime.shape)
|
| 1172 |
+
# print("HLong_prime shape:", HLong_prime.shape)
|
| 1173 |
+
diff_loss = self.diff_loss(HShort_prime, HLong_prime) + diff_loss
|
| 1174 |
+
#print("diff_loss:", diff_loss)
|
| 1175 |
+
# print(diff_loss.item())
|
| 1176 |
+
# 互交叉注意力机制
|
| 1177 |
+
A1 = F.softmax(torch.bmm(torch.matmul(HShort_prime, self.affine1), torch.transpose(HLong_prime, 1, 2)), dim=2)
|
| 1178 |
+
A2 = F.softmax(torch.bmm(torch.matmul(HLong_prime, self.affine2), torch.transpose(HShort_prime, 1, 2)), dim=2)
|
| 1179 |
+
|
| 1180 |
+
HShort_prime_new = torch.bmm(A1, HLong_prime) # 更新的短时特征
|
| 1181 |
+
HLong_prime_new = torch.bmm(A2, HShort_prime) # 更新的长时特征
|
| 1182 |
+
|
| 1183 |
+
HShort_prime_out = self.dropout(HShort_prime_new) if l < self.args.gnn_layers - 1 else HShort_prime_new
|
| 1184 |
+
HLong_prime_out = self.dropout(HLong_prime_new) if l <self.args.gnn_layers - 1 else HLong_prime_new
|
| 1185 |
+
|
| 1186 |
+
H_layer = torch.cat([HShort_prime_out, HLong_prime_out], dim=2)
|
| 1187 |
+
H_final.append(H_layer)
|
| 1188 |
+
H_final = torch.cat(H_final, dim=2)
|
| 1189 |
+
H_final = torch.cat([H_final, features], dim=2)
|
| 1190 |
+
|
| 1191 |
+
|
| 1192 |
+
# print("H_final shape:", H_final.shape)
|
| 1193 |
+
# print("H:", H.shape)
|
| 1194 |
+
# print("H_final shape after cat:", H_final.shape)
|
| 1195 |
+
H_final = self.attentive_node_features(H_final,lengths,self.nodal_att_type)
|
| 1196 |
+
# print("H_final shape after attentive_node_features:", H_final.shape)
|
| 1197 |
+
logits = self.out_mlp(H_final)
|
| 1198 |
+
# print(diff_loss)
|
| 1199 |
+
return logits, self.beta * diff_loss
|
model_utils.py
ADDED
|
@@ -0,0 +1,507 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.autograd import Variable
|
| 5 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 6 |
+
import numpy as np, itertools, random, copy, math
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DiffLoss(nn.Module):
|
| 10 |
+
|
| 11 |
+
def __init__(self, args):
|
| 12 |
+
super(DiffLoss, self).__init__()
|
| 13 |
+
|
| 14 |
+
def forward(self, input1, input2):
|
| 15 |
+
|
| 16 |
+
# input1 (B,N,D) input2 (B,N,D)
|
| 17 |
+
|
| 18 |
+
batch_size = input1.size(0)
|
| 19 |
+
N = input1.size(1)
|
| 20 |
+
input1 = input1.view(batch_size, -1) # (B,N*D)
|
| 21 |
+
input2 = input2.view(batch_size, -1) # (B, N*D)
|
| 22 |
+
# print('input1:', input1)
|
| 23 |
+
# print('input2:', input2)
|
| 24 |
+
# Zero mean
|
| 25 |
+
input1_mean = torch.mean(input1, dim=0, keepdim=True) # (1,N*D)
|
| 26 |
+
input2_mean = torch.mean(input2, dim=0, keepdim=True) # (1,N*D)
|
| 27 |
+
input1 = input1 - input1_mean # (B,N*D)
|
| 28 |
+
input2 = input2 - input2_mean # (B,N*D)
|
| 29 |
+
|
| 30 |
+
input1_l2_norm = torch.norm(input1, p=2, dim=1, keepdim=True) # (B,1)
|
| 31 |
+
input1_l2 = input1.div(input1_l2_norm.expand_as(input1) + 1e-6) # (B,N*D)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
input2_l2_norm = torch.norm(input2, p=2, dim=1, keepdim=True) # (B,1)
|
| 35 |
+
input2_l2 = input2.div(input2_l2_norm.expand_as(input2) + 1e-6) # (B,N*D)
|
| 36 |
+
# print("input1_l2_norm:", input1_l2_norm.detach().cpu().numpy())
|
| 37 |
+
# print("input2_l2_norm:", input2_l2_norm.detach().cpu().numpy())
|
| 38 |
+
# print("input1_l2:", input1_l2.detach().cpu().numpy())
|
| 39 |
+
# print("input2_l2:", input2_l2.detach().cpu().numpy())
|
| 40 |
+
norm_diff = torch.mean(torch.norm(input1_l2 - input2_l2, p=2, dim=1))
|
| 41 |
+
if norm_diff.item() == 0:
|
| 42 |
+
return torch.tensor(float('inf'), device=input1.device)
|
| 43 |
+
diff_loss = 1.0 / norm_diff
|
| 44 |
+
|
| 45 |
+
# print('loss:', diff_loss)
|
| 46 |
+
|
| 47 |
+
return diff_loss
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class MaskedNLLLoss(nn.Module):
|
| 51 |
+
|
| 52 |
+
def __init__(self, weight=None):
|
| 53 |
+
super(MaskedNLLLoss, self).__init__()
|
| 54 |
+
self.weight = weight
|
| 55 |
+
self.loss = nn.NLLLoss(weight=weight,
|
| 56 |
+
reduction='sum')
|
| 57 |
+
|
| 58 |
+
def forward(self, pred, target, mask):
|
| 59 |
+
"""
|
| 60 |
+
pred -> batch*seq_len, n_classes
|
| 61 |
+
target -> batch*seq_len
|
| 62 |
+
mask -> batch, seq_len
|
| 63 |
+
"""
|
| 64 |
+
mask_ = mask.view(-1, 1) # batch*seq_len, 1
|
| 65 |
+
if type(self.weight) == type(None):
|
| 66 |
+
loss = self.loss(pred * mask_, target) / torch.sum(mask)
|
| 67 |
+
else:
|
| 68 |
+
loss = self.loss(pred * mask_, target) \
|
| 69 |
+
/ torch.sum(self.weight[target] * mask_.squeeze())
|
| 70 |
+
return loss
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class MaskedMSELoss(nn.Module):
|
| 74 |
+
|
| 75 |
+
def __init__(self):
|
| 76 |
+
super(MaskedMSELoss, self).__init__()
|
| 77 |
+
self.loss = nn.MSELoss(reduction='sum')
|
| 78 |
+
|
| 79 |
+
def forward(self, pred, target, mask):
|
| 80 |
+
"""
|
| 81 |
+
pred -> batch*seq_len
|
| 82 |
+
target -> batch*seq_len
|
| 83 |
+
mask -> batch*seq_len
|
| 84 |
+
"""
|
| 85 |
+
loss = self.loss(pred * mask, target) / torch.sum(mask)
|
| 86 |
+
return loss
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class UnMaskedWeightedNLLLoss(nn.Module):
|
| 90 |
+
|
| 91 |
+
def __init__(self, weight=None):
|
| 92 |
+
super(UnMaskedWeightedNLLLoss, self).__init__()
|
| 93 |
+
self.weight = weight
|
| 94 |
+
self.loss = nn.NLLLoss(weight=weight,
|
| 95 |
+
reduction='sum')
|
| 96 |
+
|
| 97 |
+
def forward(self, pred, target):
|
| 98 |
+
"""
|
| 99 |
+
pred -> batch*seq_len, n_classes
|
| 100 |
+
target -> batch*seq_len
|
| 101 |
+
"""
|
| 102 |
+
if type(self.weight) == type(None):
|
| 103 |
+
loss = self.loss(pred, target)
|
| 104 |
+
else:
|
| 105 |
+
loss = self.loss(pred, target) \
|
| 106 |
+
/ torch.sum(self.weight[target])
|
| 107 |
+
return loss
|
| 108 |
+
|
| 109 |
+
class GatedSelection(nn.Module):
|
| 110 |
+
def __init__(self, hidden_size):
|
| 111 |
+
super().__init__()
|
| 112 |
+
self.context_trans = nn.Linear(hidden_size, hidden_size)
|
| 113 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size)
|
| 114 |
+
self.linear2 = nn.Linear(hidden_size, hidden_size)
|
| 115 |
+
self.fc = nn.Linear(hidden_size, hidden_size)
|
| 116 |
+
self.sigmoid = nn.Sigmoid()
|
| 117 |
+
self.relu = nn.ReLU()
|
| 118 |
+
|
| 119 |
+
def forward(self, x1, x2):
|
| 120 |
+
x2 = self.context_trans(x2)
|
| 121 |
+
s = self.sigmoid(self.linear1(x1)+self.linear2(x2))
|
| 122 |
+
h = s * x1 + (1 - s) * x2
|
| 123 |
+
return self.relu(self.fc(h))
|
| 124 |
+
|
| 125 |
+
def mask_logic(alpha, adj):
|
| 126 |
+
'''
|
| 127 |
+
performing mask logic with adj
|
| 128 |
+
:param alpha:
|
| 129 |
+
:param adj:
|
| 130 |
+
:return:
|
| 131 |
+
'''
|
| 132 |
+
return alpha - (1 - adj) * 1e30
|
| 133 |
+
|
| 134 |
+
class GatLinear(nn.Module):
|
| 135 |
+
def __init__(self, hidden_size):
|
| 136 |
+
super().__init__()
|
| 137 |
+
self.linear = nn.Linear(hidden_size * 2, 1)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def forward(self, Q, K, V, adj):
|
| 141 |
+
'''
|
| 142 |
+
imformation gatherer with linear attention
|
| 143 |
+
:param Q: (B, D) # query utterance
|
| 144 |
+
:param K: (B, N, D) # context
|
| 145 |
+
:param V: (B, N, D) # context
|
| 146 |
+
:param adj: (B, N) # the adj matrix of the i th node
|
| 147 |
+
:return:
|
| 148 |
+
'''
|
| 149 |
+
N = K.size()[1]
|
| 150 |
+
# print('Q',Q.size())
|
| 151 |
+
Q = Q.unsqueeze(1).expand(-1, N, -1) # (B, N, D)
|
| 152 |
+
# print('K',K.size())
|
| 153 |
+
X = torch.cat((Q,K), dim = 2) # (B, N, 2D)
|
| 154 |
+
# print('X',X.size())
|
| 155 |
+
alpha = self.linear(X).permute(0,2,1) #(B, 1, N)
|
| 156 |
+
# print('alpha',alpha.size())
|
| 157 |
+
# print(alpha)
|
| 158 |
+
adj = adj.unsqueeze(1)
|
| 159 |
+
alpha = mask_logic(alpha, adj) # (B, 1, N)
|
| 160 |
+
# print('alpha after mask',alpha.size())
|
| 161 |
+
# print(alpha)
|
| 162 |
+
|
| 163 |
+
attn_weight = F.softmax(alpha, dim = 2) # (B, 1, N)
|
| 164 |
+
# print('attn_weight',attn_weight.size())
|
| 165 |
+
# print(attn_weight)
|
| 166 |
+
|
| 167 |
+
attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D)
|
| 168 |
+
# print('attn_sum',attn_sum.size())
|
| 169 |
+
|
| 170 |
+
return attn_weight, attn_sum
|
| 171 |
+
|
| 172 |
+
class GatDot(nn.Module):
|
| 173 |
+
def __init__(self, hidden_size):
|
| 174 |
+
super().__init__()
|
| 175 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size)
|
| 176 |
+
self.linear2 = nn.Linear(hidden_size, hidden_size)
|
| 177 |
+
|
| 178 |
+
def forward(self, Q, K, V, adj):
|
| 179 |
+
'''
|
| 180 |
+
imformation gatherer with dot product attention
|
| 181 |
+
:param Q: (B, D) # query utterance
|
| 182 |
+
:param K: (B, N, D) # context
|
| 183 |
+
:param V: (B, N, D) # context
|
| 184 |
+
:param adj: (B, N) # the adj matrix of the i th node
|
| 185 |
+
:return:
|
| 186 |
+
'''
|
| 187 |
+
N = K.size()[1]
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
Q = self.linear1(Q).unsqueeze(2) # (B,D,1)
|
| 191 |
+
# K = self.linear2(Q) # (B, N, D)
|
| 192 |
+
K = self.linear2(K) # (B, N, D)
|
| 193 |
+
|
| 194 |
+
alpha = torch.bmm(K, Q).permute(0, 2, 1) # (B, 1, N)
|
| 195 |
+
|
| 196 |
+
adj = adj.unsqueeze(1)
|
| 197 |
+
alpha = mask_logic(alpha, adj) # (B, 1, N)
|
| 198 |
+
|
| 199 |
+
attn_weight = F.softmax(alpha, dim=2) # (B, 1, N)
|
| 200 |
+
|
| 201 |
+
attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D)
|
| 202 |
+
|
| 203 |
+
return attn_weight, attn_sum
|
| 204 |
+
|
| 205 |
+
class GatLinear_rel(nn.Module):
|
| 206 |
+
def __init__(self, hidden_size):
|
| 207 |
+
super().__init__()
|
| 208 |
+
self.linear = nn.Linear(hidden_size * 3, 1)
|
| 209 |
+
self.rel_emb = nn.Embedding(2, hidden_size)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def forward(self, Q, K, V, adj, s_mask):
|
| 213 |
+
'''
|
| 214 |
+
imformation gatherer with linear attention
|
| 215 |
+
:param Q: (B, D) # query utterance
|
| 216 |
+
:param K: (B, N, D) # context
|
| 217 |
+
:param V: (B, N, D) # context
|
| 218 |
+
:param adj: (B, N) # the adj matrix of the i th node
|
| 219 |
+
:param s_mask: (B, N) #
|
| 220 |
+
:return:
|
| 221 |
+
'''
|
| 222 |
+
rel_emb = self.rel_emb(s_mask) # (B, N, D)
|
| 223 |
+
N = K.size()[1]
|
| 224 |
+
# print('Q',Q.size())
|
| 225 |
+
Q = Q.unsqueeze(1).expand(-1, N, -1) # (B, N, D)
|
| 226 |
+
# print('K',K.size())
|
| 227 |
+
# print('rel_emb', rel_emb.size())
|
| 228 |
+
X = torch.cat((Q,K, rel_emb), dim = 2) # (B, N, 2D)? (B, N, 3D)
|
| 229 |
+
# print('X',X.size())
|
| 230 |
+
alpha = self.linear(X).permute(0,2,1) #(B, 1, N)
|
| 231 |
+
# print('alpha',alpha.size())
|
| 232 |
+
# print(alpha)
|
| 233 |
+
adj = adj.unsqueeze(1)
|
| 234 |
+
alpha = mask_logic(alpha, adj) # (B, 1, N)
|
| 235 |
+
# print('alpha after mask',alpha.size())
|
| 236 |
+
# print(alpha)
|
| 237 |
+
|
| 238 |
+
attn_weight = F.softmax(alpha, dim = 2) # (B, 1, N)
|
| 239 |
+
# print('attn_weight',attn_weight.size())
|
| 240 |
+
# print(attn_weight)
|
| 241 |
+
|
| 242 |
+
attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D)
|
| 243 |
+
# print('attn_sum',attn_sum.size())
|
| 244 |
+
|
| 245 |
+
return attn_weight, attn_sum
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class GatDot_rel(nn.Module):
|
| 249 |
+
def __init__(self, hidden_size):
|
| 250 |
+
super().__init__()
|
| 251 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size)
|
| 252 |
+
self.linear2 = nn.Linear(hidden_size, hidden_size)
|
| 253 |
+
self.linear3 = nn.Linear(hidden_size, 1)
|
| 254 |
+
self.rel_emb = nn.Embedding(2, hidden_size)
|
| 255 |
+
|
| 256 |
+
def forward(self, Q, K, V, adj, s_mask):
|
| 257 |
+
'''
|
| 258 |
+
imformation gatherer with dot product attention
|
| 259 |
+
:param Q: (B, D) # query utterance
|
| 260 |
+
:param K: (B, N, D) # context
|
| 261 |
+
:param V: (B, N, D) # context
|
| 262 |
+
:param adj: (B, N) # the adj matrix of the i th node
|
| 263 |
+
:param s_mask: (B, N) # relation mask
|
| 264 |
+
:return:
|
| 265 |
+
'''
|
| 266 |
+
N = K.size()[1]
|
| 267 |
+
|
| 268 |
+
rel_emb = self.rel_emb(s_mask)
|
| 269 |
+
Q = self.linear1(Q).unsqueeze(2) # (B,D,1)
|
| 270 |
+
K = self.linear2(K) # (B, N, D)
|
| 271 |
+
y = self.linear3(rel_emb) # (B, N, 1)
|
| 272 |
+
|
| 273 |
+
alpha = (torch.bmm(K, Q) + y).permute(0, 2, 1) # (B, 1, N)
|
| 274 |
+
|
| 275 |
+
adj = adj.unsqueeze(1)
|
| 276 |
+
alpha = mask_logic(alpha, adj) # (B, 1, N)
|
| 277 |
+
|
| 278 |
+
attn_weight = F.softmax(alpha, dim=2) # (B, 1, N)
|
| 279 |
+
|
| 280 |
+
attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D)
|
| 281 |
+
|
| 282 |
+
return attn_weight, attn_sum
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class GAT_dialoggcn(nn.Module):
|
| 286 |
+
'''
|
| 287 |
+
H_i = alpha_ij(W_rH_j)
|
| 288 |
+
alpha_ij = attention(H_i, H_j)
|
| 289 |
+
'''
|
| 290 |
+
def __init__(self, hidden_size):
|
| 291 |
+
super().__init__()
|
| 292 |
+
self.hidden_size = hidden_size
|
| 293 |
+
self.linear = nn.Linear(hidden_size * 2, 1)
|
| 294 |
+
self.rel_emb = nn.Parameter(torch.randn(2, hidden_size, hidden_size))
|
| 295 |
+
|
| 296 |
+
def forward(self, Q, K, V, adj, s_mask_onehot):
|
| 297 |
+
'''
|
| 298 |
+
imformation gatherer with linear attention
|
| 299 |
+
:param Q: (B, D) # query utterance
|
| 300 |
+
:param K: (B, N, D) # context
|
| 301 |
+
:param V: (B, N, D) # context
|
| 302 |
+
:param adj: (B, N) # the adj matrix of the i th node
|
| 303 |
+
:param s_mask: (B, N, 2) #
|
| 304 |
+
:return:
|
| 305 |
+
'''
|
| 306 |
+
B = K.size()[0]
|
| 307 |
+
N = K.size()[1]
|
| 308 |
+
# print('Q',Q.size())
|
| 309 |
+
Q = Q.unsqueeze(1).expand(-1, N, -1) # (B, N, D);
|
| 310 |
+
# print('K',K.size())
|
| 311 |
+
X = torch.cat((Q,K), dim = 2) # (B, N, 2D)
|
| 312 |
+
# print('X',X.size())
|
| 313 |
+
alpha = self.linear(X).permute(0,2,1) #(B, 1, N)
|
| 314 |
+
# print('alpha',alpha.size())
|
| 315 |
+
# print(alpha)
|
| 316 |
+
adj = adj.unsqueeze(1)
|
| 317 |
+
alpha = mask_logic(alpha, adj) # (B, 1, N)
|
| 318 |
+
# print('alpha after mask',alpha.size())
|
| 319 |
+
# print(alpha)
|
| 320 |
+
|
| 321 |
+
attn_weight = F.softmax(alpha, dim = 2) # (B, 1, N)
|
| 322 |
+
# print('attn_weight',attn_weight.size())
|
| 323 |
+
# print(attn_weight)
|
| 324 |
+
|
| 325 |
+
# print('s_mask_onehot', s_mask_onehot.size())
|
| 326 |
+
D = self.rel_emb.size()[2]
|
| 327 |
+
# print('rel_emb', self.rel_emb.size())
|
| 328 |
+
rel_emb = self.rel_emb.unsqueeze(0).expand(B,-1,-1,-1)
|
| 329 |
+
# rel_emb = self.rel_emb.unsqueeze(0).repeat(B, 1, 1, 1)
|
| 330 |
+
# print('rel_emb expand', rel_emb.size())
|
| 331 |
+
|
| 332 |
+
rel_emb = rel_emb.reshape((B, 2, D*D))
|
| 333 |
+
# print('rel_emb resize', rel_emb.size())
|
| 334 |
+
Wr = torch.bmm(s_mask_onehot, rel_emb).reshape((B, N, D, D)) # (B, N, D, D)
|
| 335 |
+
# print('Wr', Wr.size()) # (B, N, D, D)
|
| 336 |
+
|
| 337 |
+
Wr = Wr.reshape((B*N, D, D))
|
| 338 |
+
# print('Wr after reshape', Wr.size())
|
| 339 |
+
|
| 340 |
+
V = V.unsqueeze(2).reshape((B*N, 1, -1)) # (B*N, 1, D)
|
| 341 |
+
# print('V after reshape', V.size())
|
| 342 |
+
V = torch.bmm(V, Wr).unsqueeze(1) #(B * N, D)
|
| 343 |
+
# print('V after transform', V.size())
|
| 344 |
+
V = V.reshape((B,N,-1))
|
| 345 |
+
# print('Final V', V.size())
|
| 346 |
+
|
| 347 |
+
attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D)
|
| 348 |
+
# print('attn_sum',attn_sum.size())
|
| 349 |
+
|
| 350 |
+
return attn_weight, attn_sum
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class GAT_dialoggcn_v1(nn.Module):
|
| 354 |
+
'''
|
| 355 |
+
use linear to avoid OOM
|
| 356 |
+
H_i = alpha_ij(W_rH_j)
|
| 357 |
+
alpha_ij = attention(H_i, H_j)
|
| 358 |
+
'''
|
| 359 |
+
def __init__(self, hidden_size):
|
| 360 |
+
super().__init__()
|
| 361 |
+
self.hidden_size = hidden_size
|
| 362 |
+
self.linear = nn.Linear(hidden_size * 2, 1)
|
| 363 |
+
self.Wr0 = nn.Linear(hidden_size, hidden_size, bias = False)
|
| 364 |
+
self.Wr1 = nn.Linear(hidden_size, hidden_size, bias = False)
|
| 365 |
+
|
| 366 |
+
def forward(self, Q, K, V, adj, s_mask):
|
| 367 |
+
'''
|
| 368 |
+
imformation gatherer with linear attention
|
| 369 |
+
:param Q: (B, D) # query utterance
|
| 370 |
+
:param K: (B, N, D) # context
|
| 371 |
+
:param V: (B, N, D) # context
|
| 372 |
+
:param adj: (B, N) # the adj matrix of the i th node
|
| 373 |
+
:param s_mask: (B, N) #
|
| 374 |
+
:return:
|
| 375 |
+
'''
|
| 376 |
+
B = K.size()[0]
|
| 377 |
+
N = K.size()[1]
|
| 378 |
+
# print('Q',Q.size())
|
| 379 |
+
Q = Q.unsqueeze(1).expand(-1, N, -1) # (B, N, D);
|
| 380 |
+
# print('K',K.size())
|
| 381 |
+
X = torch.cat((Q,K), dim = 2) # (B, N, 2D)
|
| 382 |
+
# print('X',X.size())
|
| 383 |
+
alpha = self.linear(X).permute(0,2,1) #(B, 1, N)
|
| 384 |
+
#alpha = F.leaky_relu(alpha)
|
| 385 |
+
# print('alpha',alpha.size())
|
| 386 |
+
# print(alpha)
|
| 387 |
+
adj = adj.unsqueeze(1) # (B, 1, N)
|
| 388 |
+
alpha = mask_logic(alpha, adj) # (B, 1, N)
|
| 389 |
+
# print('alpha after mask',alpha.size())
|
| 390 |
+
# print(alpha)
|
| 391 |
+
|
| 392 |
+
attn_weight = F.softmax(alpha, dim = 2) # (B, 1, N)
|
| 393 |
+
# print('attn_weight',attn_weight.size())
|
| 394 |
+
# print(attn_weight)
|
| 395 |
+
|
| 396 |
+
V0 = self.Wr0(V) # (B, N, D)
|
| 397 |
+
V1 = self.Wr1(V) # (B, N, D)
|
| 398 |
+
|
| 399 |
+
s_mask = s_mask.unsqueeze(2).float() # (B, N, 1)
|
| 400 |
+
V = V0 * s_mask + V1 * (1 - s_mask)
|
| 401 |
+
|
| 402 |
+
attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D)
|
| 403 |
+
# print('attn_sum',attn_sum.size())
|
| 404 |
+
|
| 405 |
+
return attn_weight, attn_sum
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
class GAT_dialoggcn_v2(nn.Module):
|
| 409 |
+
'''
|
| 410 |
+
use linear to avoid OOM
|
| 411 |
+
H_i = alpha_ij(W_rH_j)
|
| 412 |
+
alpha_ij = attention(H_i, H_j, rel)
|
| 413 |
+
'''
|
| 414 |
+
def __init__(self, hidden_size):
|
| 415 |
+
super().__init__()
|
| 416 |
+
self.hidden_size = hidden_size
|
| 417 |
+
self.linear = nn.Linear(hidden_size * 3, 1)
|
| 418 |
+
self.Wr0 = nn.Linear(hidden_size, hidden_size, bias = False)
|
| 419 |
+
self.Wr1 = nn.Linear(hidden_size, hidden_size, bias = False)
|
| 420 |
+
self.rel_emb = nn.Embedding(2, hidden_size)
|
| 421 |
+
|
| 422 |
+
def forward(self, Q, K, V, adj, s_mask):
|
| 423 |
+
'''
|
| 424 |
+
imformation gatherer with linear attention
|
| 425 |
+
:param Q: (B, D) # query utterance
|
| 426 |
+
:param K: (B, N, D) # context
|
| 427 |
+
:param V: (B, N, D) # context
|
| 428 |
+
:param adj: (B, N) # the adj matrix of the i th node
|
| 429 |
+
:param s_mask: (B, N) #
|
| 430 |
+
:return:
|
| 431 |
+
'''
|
| 432 |
+
rel_emb = self.rel_emb(s_mask) # (B, N, D)
|
| 433 |
+
B = K.size()[0]
|
| 434 |
+
N = K.size()[1]
|
| 435 |
+
# print('Q',Q.size())
|
| 436 |
+
Q = Q.unsqueeze(1).expand(-1, N, -1) # (B, N, D);
|
| 437 |
+
# print('K',K.size())
|
| 438 |
+
X = torch.cat((Q,K,rel_emb), dim = 2) # (B, N, 3D)
|
| 439 |
+
# print('X',X.size())
|
| 440 |
+
alpha = self.linear(X).permute(0,2,1) #(B, 1, N)
|
| 441 |
+
# print('alpha',alpha.size())
|
| 442 |
+
# print(alpha)
|
| 443 |
+
adj = adj.unsqueeze(1)
|
| 444 |
+
alpha = mask_logic(alpha, adj) # (B, 1, N)
|
| 445 |
+
# print('alpha after mask',alpha.size())
|
| 446 |
+
# print(alpha)
|
| 447 |
+
|
| 448 |
+
attn_weight = F.softmax(alpha, dim = 2) # (B, 1, N)
|
| 449 |
+
# print('attn_weight',attn_weight.size())
|
| 450 |
+
# print(attn_weight)
|
| 451 |
+
|
| 452 |
+
V0 = self.Wr0(V) # (B, N,D)
|
| 453 |
+
V1 = self.Wr1(V) # (B, N, D)
|
| 454 |
+
|
| 455 |
+
s_mask = s_mask.unsqueeze(2).float()
|
| 456 |
+
V = V0 * s_mask + V1 * (1 - s_mask)
|
| 457 |
+
|
| 458 |
+
attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D)
|
| 459 |
+
# print('attn_sum',attn_sum.size())
|
| 460 |
+
|
| 461 |
+
return attn_weight, attn_sum
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
class attentive_node_features(nn.Module):
|
| 465 |
+
'''
|
| 466 |
+
Method to obtain attentive node features over the graph convoluted features
|
| 467 |
+
'''
|
| 468 |
+
def __init__(self, hidden_size):
|
| 469 |
+
super().__init__()
|
| 470 |
+
self.transform = nn.Linear(hidden_size, hidden_size)
|
| 471 |
+
|
| 472 |
+
def forward(self,features, lengths, nodal_att_type):
|
| 473 |
+
'''
|
| 474 |
+
features : (B, N, V)
|
| 475 |
+
lengths : (B, )
|
| 476 |
+
nodal_att_type : type of the final nodal attention
|
| 477 |
+
'''
|
| 478 |
+
|
| 479 |
+
if nodal_att_type==None:
|
| 480 |
+
return features
|
| 481 |
+
|
| 482 |
+
batch_size = features.size(0)
|
| 483 |
+
max_seq_len = features.size(1)
|
| 484 |
+
padding_mask = [l*[1]+(max_seq_len-l)*[0] for l in lengths]
|
| 485 |
+
padding_mask = torch.tensor(padding_mask).to(features) # (B, N)
|
| 486 |
+
causal_mask = torch.ones(max_seq_len, max_seq_len).to(features) # (N, N)
|
| 487 |
+
causal_mask = torch.tril(causal_mask).unsqueeze(0) # (1, N, N)
|
| 488 |
+
|
| 489 |
+
if nodal_att_type=='global':
|
| 490 |
+
mask = padding_mask.unsqueeze(1)
|
| 491 |
+
elif nodal_att_type=='past':
|
| 492 |
+
mask = padding_mask.unsqueeze(1)*causal_mask
|
| 493 |
+
|
| 494 |
+
x = self.transform(features) # (B, N, V)
|
| 495 |
+
temp = torch.bmm(x, features.permute(0,2,1))
|
| 496 |
+
#print(temp)
|
| 497 |
+
alpha = F.softmax(torch.tanh(temp), dim=2) # (B, N, N)
|
| 498 |
+
alpha_masked = alpha*mask # (B, N, N)
|
| 499 |
+
|
| 500 |
+
alpha_sum = torch.sum(alpha_masked, dim=2, keepdim=True) # (B, N, 1)
|
| 501 |
+
#print(alpha_sum)
|
| 502 |
+
alpha = alpha_masked / alpha_sum # (B, N, N)
|
| 503 |
+
attn_pool = torch.bmm(alpha, features) # (B, N, V)
|
| 504 |
+
|
| 505 |
+
return attn_pool
|
| 506 |
+
|
| 507 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.0.0+cu117
|
| 2 |
+
transformers==4.46.3
|
| 3 |
+
numpy==1.24.2
|
| 4 |
+
pandas==2.1.4
|
| 5 |
+
matplotlib==3.7.1
|
| 6 |
+
scikit-learn==1.2.2
|
| 7 |
+
tqdm==4.67.1
|
run.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
|
| 3 |
+
import numpy as np, argparse, time, pickle, random
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
from dataloader import IEMOCAPDataset, get_train_loader
|
| 8 |
+
from model import *
|
| 9 |
+
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score, classification_report, \
|
| 10 |
+
precision_recall_fscore_support
|
| 11 |
+
from trainer import train_or_eval_model, save_badcase
|
| 12 |
+
from dataset import IEMOCAPDataset
|
| 13 |
+
from dataloader import get_IEMOCAP_loaders
|
| 14 |
+
from transformers import AdamW
|
| 15 |
+
import copy
|
| 16 |
+
|
| 17 |
+
# We use seed = 100 for reproduction of the results reported in the paper.
|
| 18 |
+
seed = 100
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
|
| 22 |
+
def get_logger(filename, verbosity=1, name=None):
|
| 23 |
+
level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
|
| 24 |
+
formatter = logging.Formatter(
|
| 25 |
+
"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
|
| 26 |
+
)
|
| 27 |
+
logger = logging.getLogger(name)
|
| 28 |
+
logger.setLevel(level_dict[verbosity])
|
| 29 |
+
|
| 30 |
+
fh = logging.FileHandler(filename, "w")
|
| 31 |
+
fh.setFormatter(formatter)
|
| 32 |
+
logger.addHandler(fh)
|
| 33 |
+
|
| 34 |
+
sh = logging.StreamHandler()
|
| 35 |
+
sh.setFormatter(formatter)
|
| 36 |
+
logger.addHandler(sh)
|
| 37 |
+
|
| 38 |
+
return logger
|
| 39 |
+
|
| 40 |
+
def seed_everything(seed=seed):
|
| 41 |
+
random.seed(seed)
|
| 42 |
+
np.random.seed(seed)
|
| 43 |
+
torch.manual_seed(seed)
|
| 44 |
+
torch.cuda.manual_seed(seed)
|
| 45 |
+
torch.cuda.manual_seed_all(seed)
|
| 46 |
+
torch.backends.cudnn.benchmark = False
|
| 47 |
+
torch.backends.cudnn.deterministic = True
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
if __name__ == '__main__':
|
| 51 |
+
|
| 52 |
+
path = './saved_models/'
|
| 53 |
+
|
| 54 |
+
parser = argparse.ArgumentParser()
|
| 55 |
+
parser.add_argument('--bert_model_dir', type=str, default='')
|
| 56 |
+
parser.add_argument('--bert_tokenizer_dir', type=str, default='')
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
parser.add_argument('--bert_dim', type = int, default=1024)
|
| 60 |
+
parser.add_argument('--hidden_dim', type = int, default=300)
|
| 61 |
+
parser.add_argument('--mlp_layers', type=int, default=2, help='Number of output mlp layers.')
|
| 62 |
+
parser.add_argument('--gnn_layers', type=int, default=2, help='Number of gnn layers.')
|
| 63 |
+
parser.add_argument('--emb_dim', type=int, default=1024, help='Feature size.')
|
| 64 |
+
|
| 65 |
+
parser.add_argument('--attn_type', type=str, default='rgcn', choices=['dotprod','linear','bilinear', 'rgcn'], help='Feature size.')
|
| 66 |
+
parser.add_argument('--no_rel_attn', action='store_true', default=False, help='no relation for edges' )
|
| 67 |
+
|
| 68 |
+
parser.add_argument('--max_sent_len', type=int, default=200,
|
| 69 |
+
help='max content length for each text, if set to 0, then the max length has no constrain')
|
| 70 |
+
|
| 71 |
+
parser.add_argument('--no_cuda', action='store_true', default=False, help='does not use GPU')
|
| 72 |
+
|
| 73 |
+
parser.add_argument('--dataset_name', default='IEMOCAP', type= str, help='dataset name, IEMOCAP or MELD or DailyDialog')
|
| 74 |
+
|
| 75 |
+
parser.add_argument('--windowps', type=int, default=1,
|
| 76 |
+
help='context window size for constructing edges in graph model for past utterances for short')
|
| 77 |
+
parser.add_argument('--windowpl', type=int, default=5,
|
| 78 |
+
help='context window size for constructing edges in graph model for past utterances for long')
|
| 79 |
+
|
| 80 |
+
parser.add_argument('--windowf', type=int, default=0,
|
| 81 |
+
help='context window size for constructing edges in graph model for future utterances')
|
| 82 |
+
|
| 83 |
+
parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
|
| 84 |
+
|
| 85 |
+
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', help='learning rate')
|
| 86 |
+
|
| 87 |
+
parser.add_argument('--dropout', type=float, default=0, metavar='dropout', help='dropout rate')
|
| 88 |
+
|
| 89 |
+
parser.add_argument('--batch_size', type=int, default=16, metavar='BS', help='batch size')
|
| 90 |
+
|
| 91 |
+
parser.add_argument('--epochs', type=int, default=20, metavar='E', help='number of epochs')
|
| 92 |
+
|
| 93 |
+
parser.add_argument('--tensorboard', action='store_true', default=False, help='Enables tensorboard log')
|
| 94 |
+
|
| 95 |
+
parser.add_argument('--nodal_att_type', type=str, default=None, choices=['global','past'], help='type of nodal attention')
|
| 96 |
+
|
| 97 |
+
parser.add_argument('--curriculum', action='store_true', default=False, help='Enables curriculum learning')
|
| 98 |
+
|
| 99 |
+
parser.add_argument('--bucket_number', type=int, default=0)
|
| 100 |
+
|
| 101 |
+
parser.add_argument('--max_epoch_per_baby_step', type=int, default=0)
|
| 102 |
+
|
| 103 |
+
parser.add_argument('--diffloss', type=float , default=0.1, help='diffloss beta')
|
| 104 |
+
|
| 105 |
+
args = parser.parse_args()
|
| 106 |
+
print(args)
|
| 107 |
+
|
| 108 |
+
seed_everything()
|
| 109 |
+
|
| 110 |
+
args.cuda = torch.cuda.is_available() and not args.no_cuda
|
| 111 |
+
|
| 112 |
+
if args.cuda:
|
| 113 |
+
print('Running on GPU')
|
| 114 |
+
else:
|
| 115 |
+
print('Running on CPU')
|
| 116 |
+
|
| 117 |
+
if args.tensorboard:
|
| 118 |
+
from tensorboardX import SummaryWriter
|
| 119 |
+
|
| 120 |
+
writer = SummaryWriter()
|
| 121 |
+
|
| 122 |
+
logger = get_logger(path + args.dataset_name + '/logging.log')
|
| 123 |
+
logger.info('start training on GPU {}!'.format(os.environ["CUDA_VISIBLE_DEVICES"]))
|
| 124 |
+
logger.info(args)
|
| 125 |
+
|
| 126 |
+
cuda = args.cuda
|
| 127 |
+
n_epochs = args.epochs
|
| 128 |
+
batch_size = args.batch_size
|
| 129 |
+
valid_loader, test_loader, speaker_vocab, label_vocab, person_vec = get_IEMOCAP_loaders(dataset_name=args.dataset_name, batch_size=batch_size, num_workers=0, args = args)
|
| 130 |
+
n_classes = len(label_vocab['itos'])
|
| 131 |
+
|
| 132 |
+
print('building model..')
|
| 133 |
+
model = DAGERC_new_4(args, n_classes)
|
| 134 |
+
if args.dataset_name == 'IEMOCAP':
|
| 135 |
+
class_labels = ['excitement', 'neutral', 'frustration', 'sadness', 'happiness', 'anger']
|
| 136 |
+
else:
|
| 137 |
+
class_labels = ['Neutral', 'Surprise', 'Fear', 'Sadness', 'Joy', 'Disgust', 'Anger']
|
| 138 |
+
|
| 139 |
+
if torch.cuda.device_count() > 1:
|
| 140 |
+
print('Multi-GPU...........')
|
| 141 |
+
model = nn.DataParallel(model,device_ids = range(torch.cuda.device_count()))
|
| 142 |
+
if cuda:
|
| 143 |
+
model.cuda()
|
| 144 |
+
|
| 145 |
+
loss_function = nn.CrossEntropyLoss(ignore_index=-1)
|
| 146 |
+
optimizer = AdamW(model.parameters() , lr=args.lr)
|
| 147 |
+
|
| 148 |
+
best_fscore,best_acc, best_loss, best_label, best_pred, best_mask = None,None, None, None, None, None
|
| 149 |
+
all_fscore, all_acc, all_loss = [], [], []
|
| 150 |
+
best_acc = 0.
|
| 151 |
+
best_fscore = 0.
|
| 152 |
+
best_epoch = 0
|
| 153 |
+
best_model = None
|
| 154 |
+
for e in range(n_epochs):
|
| 155 |
+
start_time = time.time()
|
| 156 |
+
#for curiculum learning
|
| 157 |
+
if e + 1 < args.bucket_number:
|
| 158 |
+
train_loader = get_train_loader(dataset_name=args.dataset_name, batch_size=batch_size, num_workers=0,
|
| 159 |
+
args=args, babystep_index=e + 1)
|
| 160 |
+
else:
|
| 161 |
+
train_loader = get_train_loader(dataset_name=args.dataset_name, batch_size=batch_size, num_workers=0,
|
| 162 |
+
args=args, babystep_index=args.bucket_number)
|
| 163 |
+
if args.dataset_name == 'DailyDialog':
|
| 164 |
+
train_loss, train_acc, _, _, train_micro_fscore, train_macro_fscore = train_or_eval_model(model,
|
| 165 |
+
loss_function,
|
| 166 |
+
train_loader, e,
|
| 167 |
+
cuda,
|
| 168 |
+
args, optimizer,
|
| 169 |
+
True)
|
| 170 |
+
valid_loss, valid_acc, _, _, valid_micro_fscore, valid_macro_fscore = train_or_eval_model(model,
|
| 171 |
+
loss_function,
|
| 172 |
+
valid_loader, e,
|
| 173 |
+
cuda, args)
|
| 174 |
+
test_loss, test_acc, test_label, test_pred, test_micro_fscore, test_macro_fscore = train_or_eval_model(
|
| 175 |
+
model, loss_function, test_loader, e, cuda, args)
|
| 176 |
+
|
| 177 |
+
all_fscore.append([valid_micro_fscore, test_micro_fscore, valid_macro_fscore, test_macro_fscore])
|
| 178 |
+
|
| 179 |
+
logger.info( 'Epoch: {}, train_loss: {}, train_acc: {}, train_micro_fscore: {}, train_macro_fscore: {}, valid_loss: {}, valid_acc: {}, valid_micro_fscore: {}, valid_macro_fscore: {}, test_loss: {}, test_acc: {}, test_micro_fscore: {}, test_macro_fscore: {}, time: {} sec'. \
|
| 180 |
+
format(e + 1, train_loss, train_acc, train_micro_fscore, train_macro_fscore, valid_loss, valid_acc, valid_micro_fscore, valid_macro_fscore, test_loss, test_acc,
|
| 181 |
+
test_micro_fscore, test_macro_fscore, round(time.time() - start_time, 2)))
|
| 182 |
+
|
| 183 |
+
else:
|
| 184 |
+
train_loss, train_acc, _, _, train_fscore, _ , _ = train_or_eval_model(model, loss_function,
|
| 185 |
+
train_loader, e, cuda,
|
| 186 |
+
args, optimizer, True)
|
| 187 |
+
valid_loss, valid_acc, _, _, valid_fscore, _ , _= train_or_eval_model(model, loss_function,
|
| 188 |
+
valid_loader, e, cuda, args)
|
| 189 |
+
test_loss, test_acc, test_label, test_pred, test_fscore, test_f1_per_class, avg_macro_fscore= train_or_eval_model(model,loss_function, test_loader, e, cuda, args)
|
| 190 |
+
|
| 191 |
+
all_fscore.append([valid_fscore, test_fscore])
|
| 192 |
+
|
| 193 |
+
logger.info(
|
| 194 |
+
'Epoch: {}, train_loss: {}, train_acc: {}, train_fscore: {}, valid_loss: {}, valid_acc: {}, valid_fscore: {}, test_loss: {}, test_acc: {}, test_fscore: {}, avg_macro_fscore: {}, time: {} sec'. \
|
| 195 |
+
format(e + 1, train_loss, train_acc, train_fscore, valid_loss, valid_acc, valid_fscore, test_loss,
|
| 196 |
+
test_acc,
|
| 197 |
+
test_fscore, avg_macro_fscore, round(time.time() - start_time, 2)))
|
| 198 |
+
|
| 199 |
+
f1_with_labels = {label: f1 for label, f1 in zip(class_labels, test_f1_per_class)}
|
| 200 |
+
|
| 201 |
+
logger.info(f"Test F1 per class: {f1_with_labels}")
|
| 202 |
+
|
| 203 |
+
if (test_fscore > best_fscore):
|
| 204 |
+
best_fscore = test_fscore
|
| 205 |
+
best_model = copy.deepcopy(model.state_dict())
|
| 206 |
+
# print(test_fscore)
|
| 207 |
+
# print(best_model)
|
| 208 |
+
best_epoch = e + 1
|
| 209 |
+
# torch.save(model.state_dict(), path + args.dataset_name + '/model_' + str(e) + '_' + str(test_acc)+ '.pkl')
|
| 210 |
+
|
| 211 |
+
e += 1
|
| 212 |
+
#save model
|
| 213 |
+
torch.save(best_model, path + args.dataset_name + '/model_' + str(best_epoch) + '_' + str(best_fscore) + '_' + str(
|
| 214 |
+
args.gnn_layers) + '.pkl')
|
| 215 |
+
# print(best_model)
|
| 216 |
+
if args.tensorboard:
|
| 217 |
+
writer.close()
|
| 218 |
+
|
| 219 |
+
logger.info('finish training!')
|
| 220 |
+
|
| 221 |
+
#print('Test performance..')
|
| 222 |
+
all_fscore = sorted(all_fscore, key=lambda x: (x[0],x[1]), reverse=True)
|
| 223 |
+
#print('Best F-Score based on validation:', all_fscore[0][1])
|
| 224 |
+
#print('Best F-Score based on test:', max([f[1] for f in all_fscore]))
|
| 225 |
+
|
| 226 |
+
#logger.info('Test performance..')
|
| 227 |
+
#logger.info('Best F-Score based on validation:{}'.format(all_fscore[0][1]))
|
| 228 |
+
#logger.info('Best F-Score based on test:{}'.format(max([f[1] for f in all_fscore])))
|
| 229 |
+
|
| 230 |
+
if args.dataset_name=='DailyDialog':
|
| 231 |
+
logger.info('Best micro/macro F-Score based on validation:{}/{}'.format(all_fscore[0][1],all_fscore[0][3]))
|
| 232 |
+
all_fscore = sorted(all_fscore, key=lambda x: x[1], reverse=True)
|
| 233 |
+
logger.info('Best micro/macro F-Score based on test:{}/{}'.format(all_fscore[0][1],all_fscore[0][3]))
|
| 234 |
+
else:
|
| 235 |
+
logger.info('Best F-Score based on validation:{}'.format(all_fscore[0][1]))
|
| 236 |
+
logger.info('Best F-Score based on test:{}'.format(max([f[1] for f in all_fscore])))
|
| 237 |
+
|
| 238 |
+
#save_badcase(best_model, test_loader, cuda, args, speaker_vocab, label_vocab)
|
| 239 |
+
|
saved_models/IEMOCAP/README.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
请在此处存放训练后IEMOCAP的模型。
|
| 2 |
+
Please store the trained IEMOCAP model here.
|
saved_models/MELD/README.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
请在此处存放训练后MELD的模型。
|
| 2 |
+
Please store the trained MELD model here.
|
saved_models/README.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
请在此处的IEMOCAP和MELD文件夹中存放训练后的模型。
|
| 2 |
+
Please store the trained models in the IEMOCAP and MELD folders here.
|
similarity_matrix.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
emotion_positions = {
|
| 5 |
+
"pleasure": (np.cos(np.pi / 20), np.sin(np.pi / 20)),
|
| 6 |
+
"happiness": (np.cos(3 * np.pi / 20), np.sin(3 * np.pi / 20)),
|
| 7 |
+
"joy": (np.cos(3 * np.pi / 20), np.sin(3 * np.pi / 20)),
|
| 8 |
+
"pride": (np.cos(5 * np.pi / 20), np.sin(5 * np.pi / 20)),
|
| 9 |
+
"elation": (np.cos(5 * np.pi / 20), np.sin(5 * np.pi / 20)),
|
| 10 |
+
"excitement": (np.cos(7 * np.pi / 20), np.sin(7 * np.pi / 20)),
|
| 11 |
+
"surprise": (np.cos(9 * np.pi / 20), np.sin(9 * np.pi / 20)),
|
| 12 |
+
"interest": (np.cos(9 * np.pi / 20), np.sin(9 * np.pi / 20)),
|
| 13 |
+
"anger": (-np.cos(9 * np.pi / 20), np.sin(9 * np.pi / 20)),
|
| 14 |
+
"irritation": (-np.cos(9 * np.pi / 20), np.sin(9 * np.pi / 20)),
|
| 15 |
+
"hate": (-np.cos(7 * np.pi / 20), np.sin(7 * np.pi / 20)),
|
| 16 |
+
"contempt": (-np.cos(5 * np.pi / 20), np.sin(5 * np.pi / 20)),
|
| 17 |
+
"disgust": (-np.cos(3 * np.pi / 20), np.sin(3 * np.pi / 20)),
|
| 18 |
+
"fear": (-np.cos(np.pi / 20), np.sin(np.pi / 20)),
|
| 19 |
+
"boredom": (-0.5, 0),
|
| 20 |
+
"disappointment": (-np.cos(np.pi / 20), -np.sin(np.pi / 20)),
|
| 21 |
+
"frustration": (-np.cos(np.pi / 20), -np.sin(np.pi / 20)),
|
| 22 |
+
"shame": (-np.cos(3 * np.pi / 20), -np.sin(3 * np.pi / 20)),
|
| 23 |
+
"regret": (-np.cos(5 * np.pi / 20), -np.sin(5 * np.pi / 20)),
|
| 24 |
+
"guilt": (-np.cos(7 * np.pi / 20), -np.sin(7 * np.pi / 20)),
|
| 25 |
+
"sadness": (-np.cos(9 * np.pi / 20), -np.sin(9 * np.pi / 20)),
|
| 26 |
+
"compassion": (np.cos(9 * np.pi / 20), -np.sin(9 * np.pi / 20)),
|
| 27 |
+
"relief": (np.cos(7 * np.pi / 20), -np.sin(7 * np.pi / 20)),
|
| 28 |
+
"admiration": (np.cos(5 * np.pi / 20), -np.sin(5 * np.pi / 20)),
|
| 29 |
+
"love": (np.cos(3 * np.pi / 20), -np.sin(3 * np.pi / 20)),
|
| 30 |
+
"contentment": (np.cos(np.pi / 20), -np.sin(np.pi / 20)),
|
| 31 |
+
"neutral": (0, 0)
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
# 计算两个情感标签之间的余弦相似度 Compute the cosine similarity between two sentiment labels.
|
| 35 |
+
def cosine_similarity(p1, p2):
|
| 36 |
+
dot_product = np.dot(p1, p2)
|
| 37 |
+
norm_p1 = np.linalg.norm(p1)
|
| 38 |
+
norm_p2 = np.linalg.norm(p2)
|
| 39 |
+
if norm_p1 == 0 or norm_p2 == 0:
|
| 40 |
+
return 0.0
|
| 41 |
+
# print(norm_p2)
|
| 42 |
+
return dot_product / (norm_p1 * norm_p2)
|
| 43 |
+
|
| 44 |
+
# 计算情感标签之间的相似度矩阵 Compute the similarity matrix between sentiment labels.
|
| 45 |
+
def compute_similarity_matrix(emotion_positions, n_dataset):
|
| 46 |
+
emotions = list(emotion_positions.keys())
|
| 47 |
+
# print(emotions)
|
| 48 |
+
n = len(emotions)
|
| 49 |
+
N = n_dataset # 总情感标签数 Total number of sentiment labels.
|
| 50 |
+
# print(N)
|
| 51 |
+
similarity_matrix = np.zeros((n, n))
|
| 52 |
+
emotion_to_index = {emotion: idx for idx, emotion in enumerate(emotions)} # 标签到索引的映射 Mapping from labels to indices.
|
| 53 |
+
# 获取 "neutral" 标签的索引 Get the index of the label "neutral".
|
| 54 |
+
neutral_index = emotions.index("neutral")
|
| 55 |
+
|
| 56 |
+
for i in range(n):
|
| 57 |
+
for j in range(n):
|
| 58 |
+
if i != j:
|
| 59 |
+
v1 = emotion_positions[emotions[i]][0]
|
| 60 |
+
v2 = emotion_positions[emotions[j]][0]
|
| 61 |
+
# print(v1)
|
| 62 |
+
p1 = emotion_positions[emotions[i]]
|
| 63 |
+
# print(p1)
|
| 64 |
+
p2 = emotion_positions[emotions[j]]
|
| 65 |
+
# print(v1)
|
| 66 |
+
# 如果两个情感标签的价度极性相反,设相似度为0 If the valence polarities of two emotion labels are opposite, set their similarity to 0.
|
| 67 |
+
if v1 * v2 < 0:
|
| 68 |
+
similarity_matrix[i][j] = 0
|
| 69 |
+
elif v1 * v2 == 0: # valence 极性为 0
|
| 70 |
+
similarity_matrix[i][j] = 1 / N # 设置为 1/N set to 1/N
|
| 71 |
+
else:
|
| 72 |
+
|
| 73 |
+
# print(v1)
|
| 74 |
+
# print(v2)
|
| 75 |
+
# print('-----')
|
| 76 |
+
similarity_matrix[i][j] = max(cosine_similarity(np.array(p1), np.array(p2)), 0)
|
| 77 |
+
|
| 78 |
+
# 特殊处理 "neutral" 标签 Special handling for the "neutral" label.
|
| 79 |
+
for i in range(n):
|
| 80 |
+
if i != neutral_index:
|
| 81 |
+
similarity_matrix[neutral_index][i] = 1 / N # 与所有其他标签的相似度为 1/N The similarity between the "neutral" label and all other labels is set to 1/N.
|
| 82 |
+
similarity_matrix[i][neutral_index] = 1 / N
|
| 83 |
+
|
| 84 |
+
return similarity_matrix, emotion_to_index
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_similarity_matrix(dataset):
|
| 88 |
+
|
| 89 |
+
if dataset == 'IEMOCAP':
|
| 90 |
+
n = 6
|
| 91 |
+
similarity_matrix , emotion_to_index = compute_similarity_matrix(emotion_positions, n)
|
| 92 |
+
else:
|
| 93 |
+
n = 7
|
| 94 |
+
similarity_matrix , emotion_to_index = compute_similarity_matrix(emotion_positions, n)
|
| 95 |
+
|
| 96 |
+
#输出相似度矩阵
|
| 97 |
+
#绝对值越接近1表示越相似,越接近0表示越不一样 The closer the absolute value is to 1, the more similar they are; the closer it is to 0, the more different they are.
|
| 98 |
+
# print("Emotion Similarity Matrix:")
|
| 99 |
+
# print(similarity_matrix)
|
| 100 |
+
# print(emotion_to_index)
|
| 101 |
+
return similarity_matrix,emotion_to_index
|
trainer.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np, argparse, time, pickle, random
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torch.utils.data.sampler import SubsetRandomSampler
|
| 7 |
+
from dataloader import IEMOCAPDataset
|
| 8 |
+
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score, classification_report, \
|
| 9 |
+
precision_recall_fscore_support
|
| 10 |
+
from utils import person_embed
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import json
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def train_or_eval_model(model, loss_function, dataloader,epoch, cuda, args, optimizer=None, train=False):
|
| 16 |
+
losses, preds, labels = [], [], []
|
| 17 |
+
scores, vids = [], []
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
assert not train or optimizer != None
|
| 21 |
+
if train:
|
| 22 |
+
model.train()
|
| 23 |
+
# dataloader = tqdm(dataloader)
|
| 24 |
+
else:
|
| 25 |
+
model.eval()
|
| 26 |
+
|
| 27 |
+
cnt = 0
|
| 28 |
+
for data in dataloader:
|
| 29 |
+
if train:
|
| 30 |
+
optimizer.zero_grad()
|
| 31 |
+
# text_ids, text_feature, speaker_ids, labels, umask = [d.cuda() for d in data] if cuda else data
|
| 32 |
+
features, label, adj_1, adj_2, s_mask, s_mask_onehot,lengths, speakers, utterances = data
|
| 33 |
+
# speaker_vec = person_embed(speaker_ids, person_vec)
|
| 34 |
+
if cuda:
|
| 35 |
+
features = features.cuda()
|
| 36 |
+
label = label.cuda()
|
| 37 |
+
adj_1 = adj_1.cuda()
|
| 38 |
+
adj_2 = adj_2.cuda()
|
| 39 |
+
s_mask = s_mask.cuda()
|
| 40 |
+
s_mask_onehot = s_mask_onehot.cuda()
|
| 41 |
+
lengths = lengths.cuda()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# print(speakers)
|
| 45 |
+
log_prob, diff_loss = model(features, adj_1, adj_2, s_mask, s_mask_onehot, lengths) # (B, N, C)
|
| 46 |
+
# print(label)
|
| 47 |
+
loss = loss_function(log_prob.permute(0,2,1), label)+ diff_loss
|
| 48 |
+
'''
|
| 49 |
+
# print(speakers)
|
| 50 |
+
log_prob = model(features, adj_1, adj_2, s_mask, s_mask_onehot, lengths) # (B, N, C)
|
| 51 |
+
# print(label)
|
| 52 |
+
loss = loss_function(log_prob.permute(0,2,1), label)
|
| 53 |
+
'''
|
| 54 |
+
label = label.cpu().numpy().tolist()
|
| 55 |
+
pred = torch.argmax(log_prob, dim = 2).cpu().numpy().tolist()
|
| 56 |
+
preds += pred
|
| 57 |
+
labels += label
|
| 58 |
+
losses.append(loss.item())
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if train:
|
| 62 |
+
loss_val = loss.item()
|
| 63 |
+
loss.backward()
|
| 64 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
| 65 |
+
if args.tensorboard:
|
| 66 |
+
for param in model.named_parameters():
|
| 67 |
+
writer.add_histogram(param[0], param[1].grad, epoch)
|
| 68 |
+
optimizer.step()
|
| 69 |
+
|
| 70 |
+
if preds != []:
|
| 71 |
+
new_preds = []
|
| 72 |
+
new_labels = []
|
| 73 |
+
for i,label in enumerate(labels):
|
| 74 |
+
for j,l in enumerate(label):
|
| 75 |
+
if l != -1:
|
| 76 |
+
new_labels.append(l)
|
| 77 |
+
new_preds.append(preds[i][j])
|
| 78 |
+
else:
|
| 79 |
+
return float('nan'), float('nan'), [], [], float('nan'), [], [], [], [], []
|
| 80 |
+
|
| 81 |
+
# print(preds.tolist())
|
| 82 |
+
# print(labels.tolist())
|
| 83 |
+
avg_loss = round(np.sum(losses) / len(losses), 4)
|
| 84 |
+
avg_accuracy = round(accuracy_score(new_labels, new_preds) * 100, 2)
|
| 85 |
+
if args.dataset_name in ['IEMOCAP', 'MELD', 'EmoryNLP']:
|
| 86 |
+
avg_fscore = round(f1_score(new_labels, new_preds, average='weighted') * 100, 2)
|
| 87 |
+
f1_per_class = f1_score(new_labels, new_preds, average=None) # List of F1 scores for each class
|
| 88 |
+
avg_macro_fscore = round(f1_score(new_labels, new_preds, average='macro') * 100, 2)
|
| 89 |
+
return avg_loss, avg_accuracy, labels, preds, avg_fscore, f1_per_class, avg_macro_fscore
|
| 90 |
+
else:
|
| 91 |
+
avg_micro_fscore = round(f1_score(new_labels, new_preds, average='micro', labels=list(range(1, 7))) * 100, 2)
|
| 92 |
+
avg_macro_fscore = round(f1_score(new_labels, new_preds, average='macro') * 100, 2)
|
| 93 |
+
return avg_loss, avg_accuracy, labels, preds, avg_micro_fscore, avg_macro_fscore
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def save_badcase(model, dataloader, cuda, args, speaker_vocab, label_vocab):
|
| 97 |
+
preds, labels = [], []
|
| 98 |
+
scores, vids = [], []
|
| 99 |
+
dialogs = []
|
| 100 |
+
speakers = []
|
| 101 |
+
|
| 102 |
+
model.eval()
|
| 103 |
+
|
| 104 |
+
for data in dataloader:
|
| 105 |
+
|
| 106 |
+
# text_ids, text_feature, speaker_ids, labels, umask = [d.cuda() for d in data] if cuda else data
|
| 107 |
+
features, label, adj,s_mask, s_mask_onehot,lengths, speaker, utterances = data
|
| 108 |
+
# speaker_vec = person_embed(speaker_ids, person_vec)
|
| 109 |
+
if cuda:
|
| 110 |
+
features = features.cuda()
|
| 111 |
+
label = label.cuda()
|
| 112 |
+
adj = adj.cuda()
|
| 113 |
+
s_mask_onehot = s_mask_onehot.cuda()
|
| 114 |
+
s_mask = s_mask.cuda()
|
| 115 |
+
lengths = lengths.cuda()
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# print(speakers)
|
| 119 |
+
log_prob = model(features, adj,s_mask, s_mask_onehot, lengths) # (B, N, C)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
label = label.cpu().numpy().tolist() # (B, N)
|
| 123 |
+
pred = torch.argmax(log_prob, dim = 2).cpu().numpy().tolist() # (B, N)
|
| 124 |
+
preds += pred
|
| 125 |
+
labels += label
|
| 126 |
+
dialogs += utterances
|
| 127 |
+
speakers += speaker
|
| 128 |
+
|
| 129 |
+
# finished here
|
| 130 |
+
|
| 131 |
+
if preds != []:
|
| 132 |
+
new_preds = []
|
| 133 |
+
new_labels = []
|
| 134 |
+
for i,label in enumerate(labels):
|
| 135 |
+
for j,l in enumerate(label):
|
| 136 |
+
if l != -1:
|
| 137 |
+
new_labels.append(l)
|
| 138 |
+
new_preds.append(preds[i][j])
|
| 139 |
+
else:
|
| 140 |
+
return
|
| 141 |
+
|
| 142 |
+
cases = []
|
| 143 |
+
for i,d in enumerate(dialogs):
|
| 144 |
+
case = []
|
| 145 |
+
for j,u in enumerate(d):
|
| 146 |
+
case.append({
|
| 147 |
+
'text': u,
|
| 148 |
+
'speaker': speaker_vocab['itos'][speakers[i][j]],
|
| 149 |
+
'label': label_vocab['itos'][labels[i][j]] if labels[i][j] != -1 else 'none',
|
| 150 |
+
'pred': label_vocab['itos'][preds[i][j]]
|
| 151 |
+
})
|
| 152 |
+
cases.append(case)
|
| 153 |
+
|
| 154 |
+
with open('badcase/%s.json'%(args.dataset_name), 'w', encoding='utf-8') as f:
|
| 155 |
+
json.dump(cases,f)
|
| 156 |
+
|
| 157 |
+
# print(preds.tolist())
|
| 158 |
+
# print(labels.tolist())
|
| 159 |
+
avg_accuracy = round(accuracy_score(new_labels, new_preds) * 100, 2)
|
| 160 |
+
if args.dataset_name in ['IEMOCAP', 'MELD', 'EmoryNLP']:
|
| 161 |
+
avg_fscore = round(f1_score(new_labels, new_preds, average='weighted') * 100, 2)
|
| 162 |
+
print('badcase saved')
|
| 163 |
+
print('test_f1', avg_fscore)
|
| 164 |
+
return
|
| 165 |
+
else:
|
| 166 |
+
avg_micro_fscore = round(f1_score(new_labels, new_preds, average='micro', labels=list(range(1, 7))) * 100, 2)
|
| 167 |
+
avg_macro_fscore = round(f1_score(new_labels, new_preds, average='macro') * 100, 2)
|
| 168 |
+
print('badcase saved')
|
| 169 |
+
print('test_micro_f1', avg_micro_fscore)
|
| 170 |
+
print('test_macro_f1', avg_macro_fscore)
|
| 171 |
+
return
|
utils.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
def person_embed(speaker_ids, person_vec):
|
| 5 |
+
'''
|
| 6 |
+
|
| 7 |
+
:param speaker_ids: torch.Tensor ( T, B)
|
| 8 |
+
:param person_vec: numpy array (num_speakers, 100)
|
| 9 |
+
:return:
|
| 10 |
+
speaker_vec: torch.Tensor (T, B, D)
|
| 11 |
+
'''
|
| 12 |
+
speaker_vec = []
|
| 13 |
+
for t in speaker_ids:
|
| 14 |
+
speaker_vec.append([person_vec[int(i)].tolist() if i != -1 else [0] * 100 for i in t])
|
| 15 |
+
speaker_vec = torch.FloatTensor(speaker_vec)
|
| 16 |
+
return speaker_vec
|