LiXinran1 commited on
Commit
26e4a00
·
verified ·
1 Parent(s): d44dc1c

Upload 33 files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,4 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
3
+ *.feature filter=lfs diff=lfs merge=lfs -text
4
+ *.json.feature filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.feature
2
+ *.pkl
README.md CHANGED
@@ -1,3 +1,26 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Long-Short Distance Graph Neural Networks and Improved Curriculum Learning for Emotion Recognition in Conversation (Accepted by ECAI2025)
2
+
3
+ Emotion Recognition in Conversation (ERC) is a practical and challenging task. This paper proposes a novel multimodal approach, the Long-Short Distance Graph Neural Network (LSDGNN). Based on the Directed Acyclic Graph (DAG), it constructs a long-distance graph neural network and a short-distance graph neural network to obtain multimodal features of distant and nearby utterances, respectively. To ensure that long- and short-distance features are as distinct as possible in representation while enabling mutual influence between the two modules, we employ a Differential Regularizer and incorporate a BiAffine Module to facilitate feature interaction. In addition, we propose an Improved Curriculum Learning (ICL) to address the challenge of data imbalance. By computing the similarity between different emotions to emphasize the shifts in similar emotions, we design a "weighted emotional shift" metric and develop a difficulty measurer, enabling a training process that prioritizes learning easy samples before harder ones. Experimental results on the IEMOCAP and MELD datasets demonstrate that our model outperforms existing benchmarks.
4
+
5
+ ## Requirements
6
+ Python 3.11
7
+ CUDA 12.2
8
+
9
+ After configuring the Python environment and CUDA, you can use `pip install -r requirements.txt` to install the following libraries.
10
+
11
+ torch==2.0.0+cu117
12
+ transformers==4.46.3
13
+ numpy==1.24.2
14
+ pandas==2.1.4
15
+ matplotlib==3.7.1
16
+ scikit-learn==1.2.2
17
+ tqdm==4.67.1
18
+
19
+ ### Training
20
+ GPU NVIDIA GeForce RTX 3090
21
+
22
+ for IEMOCAP:
23
+ `python run.py --dataset_name IEMOCAP --gnn_layers 4 --lr 0.0005 --batch_size 16 --epochs 30 --dropout 0.4 --emb_dim 2948 --windowpl 5 --diffloss 0.1 --curriculum --bucket_number 5`
24
+
25
+ for MELD:
26
+ `python run.py --dataset_name MELD --gnn_layers 2 --lr 0.00001 --batch_size 64 --epochs 30 --dropout 0.1 --emb_dim 1666 --windowpl 5 --diffloss 0.2 --curriculum --bucket_number 12`
cl.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import similarity_matrix
3
+
4
class Dialog:
    """One conversation plus its curriculum-learning difficulty statistics.

    Difficulty follows the "weighted emotional shift" metric of the Improved
    Curriculum Learning (ICL) scheme: the more often (and the more strongly,
    by emotion similarity) each speaker's emotion shifts between consecutive
    utterances, the harder the dialog is considered.
    """

    def __init__(self, utterances, labels, speakers, features, dataset):
        self.utterances = utterances           # list of utterance texts
        self.labels = labels                   # emotion label ids (-1 = unlabeled)
        self.speakers = speakers               # speaker id per utterance
        self.features = features               # per-utterance feature vectors
        self.dataset = dataset                 # 'MELD' or IEMOCAP-style; selects label->name map
        self.numberofemotionshifts = 0         # raw count of per-speaker emotion shifts
        self.numberofspeakers = 0              # distinct speakers in this dialog
        self.numberofutterances = 0            # utterance count
        self.difficulty = 0                    # curriculum difficulty score
        self.emotion_variance = 0              # kept for compatibility; currently unused
        self.emotion_shift_weighted = 0        # similarity-weighted total of emotion shifts
        self.cc()

    def __getitem__(self, item):
        """Dict-style access so a Dialog can stand in for the old dict records."""
        fields = {
            'utterances': self.utterances,
            'labels': self.labels,
            'speakers': self.speakers,
            'features': self.features,
        }
        if item not in fields:
            # Previously an unknown key silently returned None; fail loudly instead.
            raise KeyError(item)
        return fields[item]

    # measure the difficulty of a dialog
    def cc(self):
        """Compute difficulty = (weighted shifts + #speakers) / (#utterances + #speakers)."""
        # Map numeric labels to emotion names; the two corpora use different label sets.
        if self.dataset == 'MELD':
            emotion_map = {-1: 'null', 0: 'neutral', 1: 'surprise', 2: 'fear',
                           3: 'sadness', 4: 'joy', 5: 'disgust', 6: 'anger'}
        else:
            emotion_map = {-1: 'null', 0: 'excitement', 1: 'neutral', 2: 'frustration',
                           3: 'sadness', 4: 'happiness', 5: 'anger'}
        self.numberofutterances = len(self.utterances)

        # Group each speaker's emotion sequence, in utterance order.
        speaker_emo = {}
        for i in range(len(self.labels)):
            speaker_emo.setdefault(self.speakers[i], []).append(emotion_map[self.labels[i]])

        # Emotion-similarity matrix for this corpus (project-local helper).
        matrix, emotion_to_index = similarity_matrix.get_similarity_matrix(self.dataset)
        # Linear scaling of the similarity score:
        # with k > 0, shifts between *similar* emotions weigh more;
        # with k < 0, shifts between *dissimilar* emotions would weigh more.
        k = 1
        b = 0.4
        for emos in speaker_emo.values():
            for current_emo, next_emo in zip(emos, emos[1:]):
                if current_emo != next_emo and current_emo != 'null' and next_emo != 'null':
                    self.numberofemotionshifts += 1
                    current_emo_index = emotion_to_index[current_emo]
                    next_emo_index = emotion_to_index[next_emo]
                    similarity_score = abs(matrix[current_emo_index][next_emo_index]) * k + b
                    self.emotion_shift_weighted += similarity_score

        self.numberofspeakers = len(set(self.speakers))
        self.difficulty = (self.emotion_shift_weighted + self.numberofspeakers) / (
            self.numberofutterances + self.numberofspeakers)
data/IEMOCAP/dev_data_roberta.json.feature ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8c1835251e4c85af23e65d1a312990b411d36574f9bd1063f6ca3d20d8f2eda
3
+ size 32689394
data/IEMOCAP/dev_data_roberta_mm.json.feature ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bd2b4d3fd2a092c2f7b390a0c5115533bf4fb5b0724f49e76a84d98bf054bdb
3
+ size 62364912
data/IEMOCAP/label_vocab.pkl ADDED
Binary file (98 Bytes). View file
 
data/IEMOCAP/speaker_vocab.pkl ADDED
Binary file (7.61 kB). View file
 
data/IEMOCAP/test_data_roberta.json.feature ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f7c81d1690f997cd02672b3dd1f077744c997c0446b9ce6599a54f37db4950f
3
+ size 51258384
data/IEMOCAP/test_data_roberta_mm.json.feature ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bcf5c03d714bbf27105e72141a6e78f1d58278d794e46be82cee53beddcfab1
3
+ size 103033265
data/IEMOCAP/train_data_roberta.json.feature ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6299f2656d0d4e7a7c2e1f91c7e0f811088d500c6320b741ce742a72d6993c99
3
+ size 151409987
data/IEMOCAP/train_data_roberta_mm.json.feature ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cdc380726af732912b539b335233053a465237d29c7369477851bf8dc62e5f28
3
+ size 307542106
data/MELD/dev_data_roberta.json.feature ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fe9a5f3b0464ff597061a833fa62cfe2337311b53699d1878bba1827106a0a8
3
+ size 29797677
data/MELD/dev_data_roberta_mm.json.feature ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01930f39300566f37d433a60c323eedc2a8cfd53486d5f5cbda56371e85f8aff
3
+ size 42149762
data/MELD/label_vocab.pkl ADDED
Binary file (128 Bytes). View file
 
data/MELD/speaker_vocab.pkl ADDED
Binary file (5.08 kB). View file
 
data/MELD/test_data_roberta.json.feature ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a60fe91031fe79749fa813336309ae5a1804561ff70dd5a6320029a4fdee0f0
3
+ size 70132507
data/MELD/test_data_roberta_mm.json.feature ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0cead63677c699daef4d382047e5592066922c6659d5bca9bc5ce57367dd65a7
3
+ size 99010213
data/MELD/train_data_roberta.json.feature ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05d4c1be081d894d9f2baa44e3a86882ccaa77cd0783e9c9092b39ad31d3bc81
3
+ size 267972749
data/MELD/train_data_roberta_mm.json.feature ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40b44a9d69baccc275fd3dcbb95aaf079ac32199d32b072c652e7691a5b46bf8
3
+ size 364368202
dataloader.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataset import *
2
+ import pickle
3
+ from torch.utils.data.sampler import SubsetRandomSampler
4
+ from torch.utils.data import DataLoader
5
+ import os
6
+ import argparse
7
+ import numpy as np
8
+ from transformers import BertTokenizer
9
+
10
def get_train_valid_sampler(trainset):
    """Return a SubsetRandomSampler covering every index of *trainset*."""
    return SubsetRandomSampler(list(range(len(trainset))))
14
+
15
+
16
def load_vocab(dataset_name):
    """Load the speaker and label vocabularies for *dataset_name*.

    Reads data/<name>/speaker_vocab.pkl and data/<name>/label_vocab.pkl
    relative to the working directory.

    Returns:
        (speaker_vocab, label_vocab, person_vec); person_vec is always None —
        the personality-vector feature is disabled in this code base.
    """
    # Use context managers so the pickle handles are closed promptly
    # (the original left open() results to the garbage collector).
    with open('data/%s/speaker_vocab.pkl' % dataset_name, 'rb') as f:
        speaker_vocab = pickle.load(f)
    with open('data/%s/label_vocab.pkl' % dataset_name, 'rb') as f:
        label_vocab = pickle.load(f)
    person_vec = None
    return speaker_vocab, label_vocab, person_vec
32
+
33
+
34
def get_IEMOCAP_loaders(dataset_name = 'IEMOCAP', batch_size=32, num_workers=0, pin_memory=False, args = None):
    """Build the dev and test DataLoaders (the train loader comes from get_train_loader).

    Returns:
        (valid_loader, test_loader, speaker_vocab, label_vocab, person_vec)
    """
    print('building vocab.. ')
    speaker_vocab, label_vocab, person_vec = load_vocab(dataset_name)
    print('building datasets..')
    devset = IEMOCAPDataset(dataset_name, 'dev', speaker_vocab, label_vocab, args)
    testset = IEMOCAPDataset(dataset_name, 'test', speaker_vocab, label_vocab, args)

    # Shared loader configuration; only the sampler/collate differ per split.
    common = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    valid_loader = DataLoader(devset,
                              sampler=get_train_valid_sampler(devset),
                              collate_fn=devset.collate_fn,
                              **common)
    test_loader = DataLoader(testset,
                             collate_fn=testset.collate_fn,
                             **common)

    return valid_loader, test_loader, speaker_vocab, label_vocab, person_vec
55
def get_train_loader(dataset_name = 'IEMOCAP', batch_size=32, num_workers=0, pin_memory=False, args = None, babystep_index = None):
    """Build the training DataLoader.

    When args.curriculum is set, *babystep_index* selects how many difficulty
    buckets of the sorted training data are included (baby-step curriculum).
    """
    print('building vocab.. ')
    speaker_vocab, label_vocab, person_vec = load_vocab(dataset_name)
    print('building datasets..')
    # The original had two fully duplicated branches that differed only in
    # whether the babystep index was forwarded to the dataset constructor.
    if args.curriculum:
        trainset = IEMOCAPDataset(dataset_name, 'train', speaker_vocab, label_vocab, args, None, babystep_index)
    else:
        trainset = IEMOCAPDataset(dataset_name, 'train', speaker_vocab, label_vocab, args)
    train_sampler = get_train_valid_sampler(trainset)
    train_loader = DataLoader(trainset,
                              batch_size=batch_size,
                              sampler=train_sampler,
                              collate_fn=trainset.collate_fn,
                              num_workers=num_workers,
                              pin_memory=pin_memory)
    return train_loader
dataset.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ from torch.nn.utils.rnn import pad_sequence
4
+ import pickle, pandas as pd
5
+ import json
6
+ import numpy as np
7
+ import random
8
+ from pandas import DataFrame
9
+
10
+ import cl
11
+
12
+
13
class IEMOCAPDataset(Dataset):
    """Conversation dataset for IEMOCAP / MELD with optional baby-step curriculum.

    babystep_index: when curriculum learning is enabled, how many difficulty
    buckets of the difficulty-sorted training dialogs to include.
    """

    def __init__(self, dataset_name = 'IEMOCAP', split = 'train', speaker_vocab=None, label_vocab=None, args = None, tokenizer = None, babystep_index = None):
        self.speaker_vocab = speaker_vocab   # speaker -> id vocabulary ({'stoi': ...})
        self.label_vocab = label_vocab       # emotion label -> id vocabulary ({'stoi': ...})
        self.args = args                     # run configuration namespace
        self.data = self.read(dataset_name, split, tokenizer)
        # Curriculum learning: keep only the first `babystep_index` buckets of
        # the difficulty-sorted training dialogs.
        if args.curriculum and split == 'train':
            self.data = self.babystep(self.getbuckets(self.data, args.bucket_number), babystep_index)
            print(len(self.data))

        self.len = len(self.data)

    def read(self, dataset_name, split, tokenizer):
        """Load dialogs from the multimodal feature file and wrap each in cl.Dialog.

        Training dialogs are sorted easy-to-hard when curriculum learning is on;
        otherwise the dialogs are shuffled.
        """
        with open('data/%s/%s_data_roberta_mm.json.feature' % (dataset_name, split), encoding='utf-8') as f:
            raw_data = json.load(f)

        dialogs = []
        for d in raw_data:
            utterances = []
            labels = []
            speakers = []
            features = []
            for u in d:
                utterances.append(u['text'])
                # Utterances without a gold label get -1 (masked out downstream).
                labels.append(self.label_vocab['stoi'][u['label']] if 'label' in u.keys() else -1)
                speakers.append(self.speaker_vocab['stoi'][u['speaker']])
                # Concatenate the three modality feature vectors
                # (presumably text/audio/visual CLS features — confirm upstream).
                features.append(u['cls'][0] + u['cls'][1] + u['cls'][2])
            dialogs.append(cl.Dialog(utterances, labels, speakers, features, self.args.dataset_name))

        if self.args.curriculum and split == 'train':
            # Easy-to-hard ordering for the curriculum buckets.
            dialogs.sort(key=lambda dialog: dialog.difficulty)
        else:
            random.shuffle(dialogs)
        return dialogs

    def getbuckets(self, dialogs, num_buckets):
        """Split *dialogs* into at most `num_buckets` contiguous, near-equal buckets."""
        bucket_length = (len(dialogs) + num_buckets - 1) // num_buckets  # ceil division
        buckets = [dialogs[i:i + bucket_length] for i in range(0, len(dialogs), bucket_length)]
        print('bucket')
        print(len(buckets))
        return buckets

    def babystep(self, buckets, index):
        """Concatenate the first *index* buckets (the easiest dialogs)."""
        data = []
        for bucket in buckets[:index]:
            data += bucket
        return data

    def __getitem__(self, index):
        """Return (features tensor, labels tensor, speakers, length, utterances)."""
        dialog = self.data[index]
        return torch.FloatTensor(dialog['features']), \
               torch.LongTensor(dialog['labels']), \
               dialog['speakers'], \
               len(dialog['labels']), \
               dialog['utterances']

    def __len__(self):
        return self.len

    def get_adj(self, speakers, max_dialog_len):
        '''
        get adj matrix
        :param speakers: (B, N)
        :param max_dialog_len:
        :return:
            adj: (B, N, N) adj[:,i,:] means the direct predecessors of node i
        '''
        adj = []
        for speaker in speakers:
            a = torch.zeros(max_dialog_len, max_dialog_len)
            for i, s in enumerate(speaker):
                # Link each utterance to its nearest same-speaker and nearest
                # other-speaker predecessor only.
                get_local_pred = False
                get_global_pred = False
                for j in range(i - 1, -1, -1):
                    if speaker[j] == s and not get_local_pred:
                        get_local_pred = True
                        a[i, j] = 1
                    elif speaker[j] != s and not get_global_pred:
                        get_global_pred = True
                        a[i, j] = 1
                    if get_global_pred and get_local_pred:
                        break
            adj.append(a)
        return torch.stack(adj)

    def get_adj_v1(self, speakers, max_dialog_len):
        '''
        get adj matrix
        :param speakers: (B, N)
        :param max_dialog_len:
        :return:
            adj: (B, N, N) adj[:,i,:] means the direct predecessors of node i
        '''
        adj = []
        for speaker in speakers:
            a = torch.zeros(max_dialog_len, max_dialog_len)
            for i, s in enumerate(speaker):
                # Link to all predecessors back to the windowps-th same-speaker one.
                # NOTE(review): uses args.windowps while get_adj_v2 uses
                # args.windowpl — presumably both are defined in run.py; confirm.
                cnt = 0
                for j in range(i - 1, -1, -1):
                    a[i, j] = 1
                    if speaker[j] == s:
                        cnt += 1
                        if cnt == self.args.windowps:
                            break
            adj.append(a)
        return torch.stack(adj)

    def get_adj_v2(self, speakers, max_dialog_len):
        '''
        get adj matrix
        :param speakers: (N)
        :param max_dialog_len:
        :return:
            adj: (N, N) adj[i,:] means the direct predecessors of node i
        '''
        adj = []
        for speaker in speakers:
            a = torch.zeros(max_dialog_len, max_dialog_len)
            for i, s in enumerate(speaker):
                # Link to all predecessors back to the windowpl-th same-speaker one.
                cnt = 0
                for j in range(i - 1, -1, -1):
                    a[i, j] = 1
                    if speaker[j] == s:
                        cnt += 1
                        if cnt == self.args.windowpl:
                            break
            adj.append(a)
        return torch.stack(adj)

    def get_s_mask(self, speakers, max_dialog_len):
        '''
        :param speakers:
        :param max_dialog_len:
        :return:
         s_mask: (B, N, N) s_mask[:,i,:] means the speaker informations for predecessors of node i, where 1 denotes the same speaker, 0 denotes the different speaker
         s_mask_onehot (B, N, N, 2) onehot encoding of s_mask
        '''
        s_mask = []
        s_mask_onehot = []
        for speaker in speakers:
            s = torch.zeros(max_dialog_len, max_dialog_len, dtype=torch.long)
            s_onehot = torch.zeros(max_dialog_len, max_dialog_len, 2)
            for i in range(len(speaker)):
                for j in range(len(speaker)):
                    if speaker[i] == speaker[j]:
                        s[i, j] = 1
                        s_onehot[i, j, 1] = 1
                    else:
                        s_onehot[i, j, 0] = 1
            s_mask.append(s)
            s_mask_onehot.append(s_onehot)
        return torch.stack(s_mask), torch.stack(s_mask_onehot)

    def collate_fn(self, data):
        '''
        :param data:
            features, labels, speakers, length, utterances
        :return:
            features: (B, N, D) padded
            labels: (B, N) padded
            adj: (B, N, N) adj[:,i,:] means the direct predecessors of node i
            s_mask: (B, N, N) s_mask[:,i,:] means the speaker informations for predecessors of node i, where 1 denotes the same speaker, 0 denotes the different speaker
            lengths: (B, )
            utterances: not a tensor
        '''
        max_dialog_len = max([d[3] for d in data])
        features = pad_sequence([d[0] for d in data], batch_first=True)  # (B, N, D); was misspelled 'feaures'
        labels = pad_sequence([d[1] for d in data], batch_first=True, padding_value=-1)  # (B, N)
        adj_1 = self.get_adj_v1([d[2] for d in data], max_dialog_len)
        adj_2 = self.get_adj_v2([d[2] for d in data], max_dialog_len)
        s_mask, s_mask_onehot = self.get_s_mask([d[2] for d in data], max_dialog_len)
        lengths = torch.LongTensor([d[3] for d in data])
        speakers = pad_sequence([torch.LongTensor(d[2]) for d in data], batch_first=True, padding_value=-1)
        utterances = [d[4] for d in data]

        return features, labels, adj_1, adj_2, s_mask, s_mask_onehot, lengths, speakers, utterances
evaluate.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["CUDA_VISIBLE_DEVICES"] = '0'
3
+ import numpy as np, argparse, time, pickle, random
4
+ import torch
5
+ import matplotlib
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from dataloader import IEMOCAPDataset
9
+ from model import *
10
+ from sklearn.metrics import f1_score, confusion_matrix, accuracy_score, classification_report, \
11
+ precision_recall_fscore_support, ConfusionMatrixDisplay
12
+ import matplotlib.pyplot as plt
13
+ from trainer import train_or_eval_model, save_badcase
14
+ from dataset import IEMOCAPDataset
15
+ from dataloader import get_IEMOCAP_loaders
16
+ from transformers import AdamW
17
+ import copy
18
+
19
+ # We use seed = 100 for reproduction of the results reported in the paper.
20
+ seed = 100
21
+
22
+
23
def seed_everything(seed=seed):
    """Seed every RNG in use (python, numpy, torch CPU and CUDA) and force
    deterministic cuDNN behavior, for reproducible runs."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
31
+
32
+
33
def evaluate(model, dataloader, cuda, args, speaker_vocab, label_vocab):
    """Run the model over *dataloader* and print accuracy / F1 scores.

    Labels padded with -1 are excluded from scoring. For IEMOCAP/MELD/EmoryNLP
    a weighted F1 is reported; otherwise micro (over classes 1..6) and macro F1.

    NOTE(review): this unpacks 8 items per batch, but dataset.collate_fn in
    this repository returns 9 (it yields both adj_1 and adj_2) — confirm which
    dataloader/collate this function is actually paired with before running.
    """
    preds, labels = [], []
    scores, vids = [], []  # collected but currently unused
    dialogs = []
    speakers = []

    model.eval()

    for data in dataloader:

        features, label, adj,s_mask, s_mask_onehot,lengths, speaker, utterances = data
        if cuda:
            # Move every tensor in the batch onto the GPU.
            features = features.cuda()
            label = label.cuda()
            adj = adj.cuda()
            s_mask_onehot = s_mask_onehot.cuda()
            s_mask = s_mask.cuda()
            lengths = lengths.cuda()

        log_prob = model(features, adj,s_mask, s_mask_onehot, lengths) # (B, N, C)

        label = label.cpu().numpy().tolist() # (B, N)
        pred = torch.argmax(log_prob, dim = 2).cpu().numpy().tolist() # (B, N)
        preds += pred
        labels += label
        dialogs += utterances
        speakers += speaker

    if preds != []:
        # Flatten, dropping padded positions (label == -1).
        new_preds = []
        new_labels = []
        for i,label in enumerate(labels):
            for j,l in enumerate(label):
                if l != -1:
                    new_labels.append(l)
                    new_preds.append(preds[i][j])
    else:
        return

    avg_accuracy = round(accuracy_score(new_labels, new_preds) * 100, 2)
    if args.dataset_name in ['IEMOCAP', 'MELD', 'EmoryNLP']:
        avg_fscore = round(f1_score(new_labels, new_preds, average='weighted') * 100, 2)
        # get f1 score for each class to generate confusion matrix
        # fscore_perclass = f1_score(new_labels, new_preds, average=None)
        # print('fscore_perclass', fscore_perclass)
        print('test_accuracy', avg_accuracy)
        print('test_f1', avg_fscore)
        # confusion matrix test, not working on colab
        # print(new_labels)
        # cm = confusion_matrix(new_labels, new_preds, labels=[0, 1, 2, 3, 4, 5, 6])
        # print(cm)
        # per_class_accuracies = {}
        #
        # # Calculate the accuracy for each one of our classes
        # for idx, cls in enumerate(label_vocab['itos']):
        #     # True negatives are all the samples that are not our current GT class (not the current row)
        #     # and were not predicted as the current class (not the current column)
        #     true_negatives = np.sum(np.delete(np.delete(cm, idx, axis=0), idx, axis=1))
        #
        #     # True positives are all the samples of our current GT class that were predicted as such
        #     true_positives = cm[idx, idx]
        #
        #     # The accuracy for the current class is the ratio between correct predictions to all predictions
        #     per_class_accuracies[cls] = (true_positives + true_negatives) / np.sum(cm)
        # print('acc:', per_class_accuracies)
        # disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_vocab['itos'])
        # disp.plot()
        # plt.show()
        return
    else:
        avg_micro_fscore = round(f1_score(new_labels, new_preds, average='micro', labels=list(range(1, 7))) * 100, 2)
        avg_macro_fscore = round(f1_score(new_labels, new_preds, average='macro') * 100, 2)
        print('test_accuracy', avg_accuracy)
        print('test_micro_f1', avg_micro_fscore)
        print('test_macro_f1', avg_macro_fscore)
        return
109
+
110
if __name__ == '__main__':

    #path = './saved_models/'

    # Evaluation entry point: parse the run configuration, rebuild the model,
    # load a trained state dict, and score it on the test split.
    parser = argparse.ArgumentParser()
    parser.add_argument('--bert_model_dir', type=str, default='')
    parser.add_argument('--bert_tokenizer_dir', type=str, default='')

    parser.add_argument('--state_dict_file', type=str, default='')

    parser.add_argument('--bert_dim', type = int, default=1024)
    parser.add_argument('--hidden_dim', type = int, default=300)
    parser.add_argument('--mlp_layers', type=int, default=2, help='Number of output mlp layers.')
    parser.add_argument('--gnn_layers', type=int, default=2, help='Number of gnn layers.')
    parser.add_argument('--emb_dim', type=int, default=1024, help='Feature size.')

    parser.add_argument('--attn_type', type=str, default='rgcn', choices=['dotprod','linear','bilinear', 'rgcn'], help='Feature size.')
    parser.add_argument('--no_rel_attn', action='store_true', default=False, help='no relation for edges' )

    parser.add_argument('--max_sent_len', type=int, default=200,
                        help='max content length for each text, if set to 0, then the max length has no constrain')

    parser.add_argument('--no_cuda', action='store_true', default=False, help='does not use GPU')

    parser.add_argument('--dataset_name', default='IEMOCAP', type= str, help='dataset name, IEMOCAP or MELD or DailyDialog')

    # NOTE(review): dataset.get_adj_v1/get_adj_v2 read args.windowps / args.windowpl,
    # which are not defined here (only --windowp/--windowf) — confirm against run.py.
    parser.add_argument('--windowp', type=int, default=1,
                        help='context window size for constructing edges in graph model for past utterances')

    parser.add_argument('--windowf', type=int, default=0,
                        help='context window size for constructing edges in graph model for future utterances')

    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')

    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', help='learning rate')

    parser.add_argument('--dropout', type=float, default=0, metavar='dropout', help='dropout rate')

    parser.add_argument('--batch_size', type=int, default=8, metavar='BS', help='batch size')

    parser.add_argument('--epochs', type=int, default=20, metavar='E', help='number of epochs')

    parser.add_argument('--tensorboard', action='store_true', default=False, help='Enables tensorboard log')

    parser.add_argument('--nodal_att_type', type=str, default=None, choices=['global', 'past'],
                        help='type of nodal attention')

    parser.add_argument('--curriculum', action='store_true', default=False, help='Enables curriculum learning')

    parser.add_argument('--bucket_number', type=int, default=0, help='Number of buckets using')

    args = parser.parse_args()
    print(args)

    seed_everything()

    args.cuda = torch.cuda.is_available() and not args.no_cuda

    if args.cuda:
        print('Running on GPU')
    else:
        print('Running on CPU')

    if args.tensorboard:
        from tensorboardX import SummaryWriter

        writer = SummaryWriter()

    cuda = args.cuda
    n_epochs = args.epochs
    batch_size = args.batch_size
    # Only dev/test loaders are needed for evaluation.
    valid_loader, test_loader, speaker_vocab, label_vocab, person_vec = get_IEMOCAP_loaders(
        dataset_name=args.dataset_name, batch_size=batch_size, num_workers=0, args=args)
    n_classes = len(label_vocab['itos'])

    print('building model..')
    model = DAGERC_fushion(args, n_classes)

    if torch.cuda.device_count() > 1:
        print('Multi-GPU...........')
        model = nn.DataParallel(model,device_ids = range(torch.cuda.device_count()))
    if cuda:
        model.cuda()

    # Restore trained weights and score on the test split.
    state_dict = torch.load(args.state_dict_file)
    model.load_state_dict(state_dict)
    evaluate(model, test_loader, cuda, args, speaker_vocab, label_vocab)
model.py ADDED
@@ -0,0 +1,1199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np, itertools, random, copy, math
5
+ from transformers import BertModel, BertConfig
6
+ from transformers import AutoTokenizer, AutoModelWithLMHead
7
+ from model_utils import *
8
+
9
+
10
class BertERC(nn.Module):
    """Utterance-level emotion classifier: a BERT encoder followed by an MLP head.

    The pooled BERT output for each input sequence is passed through dropout and
    an MLP that projects to `num_class` logits.
    """

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args
        # gcn layer

        self.dropout = nn.Dropout(args.dropout)
        # bert_encoder
        # NOTE(review): the config is read from args.bert_model_dir but the weights
        # from args.home_dir + args.bert_model_dir — confirm this path asymmetry
        # (and the missing '/' separator before 'config.json') is intentional.
        self.bert_config = BertConfig.from_json_file(args.bert_model_dir + 'config.json')

        self.bert = BertModel.from_pretrained(args.home_dir + args.bert_model_dir, config = self.bert_config)
        in_dim = args.bert_dim

        # output mlp layers: (mlp_layers - 1) hidden Linear+ReLU pairs, then the
        # final projection to class logits.
        layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers- 1):
            layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
        layers += [nn.Linear(args.hidden_dim, num_class)]

        self.out_mlp = nn.Sequential(*layers)

    def forward(self, content_ids, token_types,utterance_len,seq_len):
        """Classify each input sequence.

        :param content_ids: token-id tensor fed to BERT
        :param token_types: token-type ids (currently unused — see commented call)
        :param utterance_len: unused in this forward pass
        :param seq_len: unused in this forward pass
        :return: (N, num_class) logits
        """
        # the embeddings for bert
        # if len(content_ids)>512:
        #     print('ll')

        #
        ## w token_type_ids
        # lastHidden = self.bert(content_ids, token_type_ids = token_types)[1] #(N , D)
        ## w/t token_type_ids
        # Index [1] selects BERT's pooled [CLS] representation.
        lastHidden = self.bert(content_ids)[1] #(N , D)

        final_feature = self.dropout(lastHidden)

        # pooling

        outputs = self.out_mlp(final_feature) #(N, D)

        return outputs
51
+
52
+
53
class DAGERC(nn.Module):
    """DAG-structured ERC model.

    Each GNN layer walks the utterances left to right: node i attends over all
    earlier nodes (edges given by `adj[:, i, :i]`) and a per-layer GRUCell fuses
    the aggregated message into the node state.  The outputs of every layer plus
    the raw input features are concatenated and classified by an MLP.
    """

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args
        # gcn layer

        self.dropout = nn.Dropout(args.dropout)

        self.gnn_layers = args.gnn_layers

        if not args.no_rel_attn:
            # Relation-aware attention: 2-entry embedding for the binary
            # same-speaker relation carried by s_mask.
            self.rel_emb = nn.Embedding(2,args.hidden_dim)
            self.rel_attn = True
        else:
            self.rel_attn = False

        if self.args.attn_type == 'linear':
            gats = []
            for _ in range(args.gnn_layers):
                gats += [GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)
        else:
            # NOTE(review): this branch uses `Gatdot`/`Gatdot_rel`, while the other
            # classes in this file spell it `GatDot`/`GatDot_rel` — confirm both
            # names exist in model_utils.
            gats = []
            for _ in range(args.gnn_layers):
                gats += [Gatdot(args.hidden_dim) if args.no_rel_attn else Gatdot_rel(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)

        grus = []
        for _ in range(args.gnn_layers):
            grus += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus = nn.ModuleList(grus)

        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)

        # Classifier input concatenates every layer's output plus the raw features.
        in_dim = args.hidden_dim * (args.gnn_layers + 1) + args.emb_dim
        # output mlp layers
        layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
        layers += [nn.Linear(args.hidden_dim, num_class)]

        self.out_mlp = nn.Sequential(*layers)

    def forward(self, features, adj,s_mask):
        '''
        :param features: (B, N, D)
        :param adj: (B, N, N)
        :param s_mask: (B, N, N)
        :return: (B, N, num_class) logits
        '''
        num_utter = features.size()[1]
        if self.rel_attn:
            rel_ft = self.rel_emb(s_mask) # (B, N, N, D)

        H0 = F.relu(self.fc1(features)) # (B, N, D)
        H = [H0]
        for l in range(self.args.gnn_layers):
            # First utterance has no predecessors: GRU update with no message.
            H1 = self.grus[l](H[l][:,0,:]).unsqueeze(1) # (B, 1, D)
            for i in range(1, num_utter):
                # Aggregate message M for node i from all earlier nodes.
                if not self.rel_attn:
                    _, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i])
                else:
                    _, M = self.gather[l](H[l][:, i, :], H1, H1, adj[:, i, :i], rel_ft[:, i, :i, :])
                H1 = torch.cat((H1 , self.grus[l](H[l][:,i,:], M).unsqueeze(1)), dim = 1)
                # print('H1', H1.size())
                # print('----------------------------------------------------')
            H.append(H1)
            H0 = H1  # NOTE(review): H0 is never read again after this — looks vestigial.
        H.append(features)
        H = torch.cat(H, dim = 2) #(B, N, l*D)
        logits = self.out_mlp(H)
        return logits
126
+
127
+
128
+
129
class DAGERC_fushion(nn.Module):
    """DAG-based ERC model with fused per-node GRU updates.

    Per GNN layer, each utterance node i attends over all earlier nodes (DAG
    edges given by `adj`), then two GRU cells combine node state and message:
        C = grus_c(node, message)   # message-conditioned node update
        P = grus_p(message, node)   # node-conditioned message update
    The new node state is C + P.  Every layer's output plus the raw features are
    concatenated, optionally re-weighted by nodal attention, and classified.
    """

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args
        # gcn layer

        self.dropout = nn.Dropout(args.dropout)

        self.gnn_layers = args.gnn_layers

        if not args.no_rel_attn:
            self.rel_attn = True
        else:
            self.rel_attn = False

        # One attention/aggregation module per GNN layer; flavor set by attn_type.
        if self.args.attn_type == 'linear':
            gats = []
            for _ in range(args.gnn_layers):
                gats += [GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)
        elif self.args.attn_type == 'dotprod':
            gats = []
            for _ in range(args.gnn_layers):
                gats += [GatDot(args.hidden_dim) if args.no_rel_attn else GatDot_rel(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)
        elif self.args.attn_type == 'rgcn':
            gats = []
            for _ in range(args.gnn_layers):
                # gats += [GAT_dialoggcn(args.hidden_dim)]
                gats += [GAT_dialoggcn_v1(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)

        grus_c = []
        for _ in range(args.gnn_layers):
            grus_c += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_c = nn.ModuleList(grus_c)

        grus_p = []
        for _ in range(args.gnn_layers):
            grus_p += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_p = nn.ModuleList(grus_p)

        # NOTE(review): self.fcs is constructed but the active fusion below is
        # C + P; the fcs path survives only in commented-out variants.
        fcs = []
        for _ in range(args.gnn_layers):
            fcs += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
        self.fcs = nn.ModuleList(fcs)

        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)

        self.nodal_att_type = args.nodal_att_type

        in_dim = args.hidden_dim * (args.gnn_layers + 1) + args.emb_dim

        # output mlp layers
        layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
        layers += [self.dropout]
        layers += [nn.Linear(args.hidden_dim, num_class)]

        self.out_mlp = nn.Sequential(*layers)

        self.attentive_node_features = attentive_node_features(in_dim)

    def forward(self, features, adj,s_mask,s_mask_onehot, lengths):
        '''
        :param features: (B, N, D)
        :param adj: (B, N, N)
        :param s_mask: (B, N, N)
        :param s_mask_onehot: (B, N, N, 2)
        :return: (B, N, num_class) logits
        '''
        num_utter = features.size()[1]

        H0 = F.relu(self.fc1(features))
        # H0 = self.dropout(H0)
        H = [H0]
        for l in range(self.args.gnn_layers):
            # Node 0 has no predecessors: zero message, GRU update only.
            C = self.grus_c[l](H[l][:,0,:]).unsqueeze(1)
            M = torch.zeros_like(C).squeeze(1)
            # P = M.unsqueeze(1)
            P = self.grus_p[l](M, H[l][:,0,:]).unsqueeze(1)
            #H1 = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
            #H1 = F.relu(C+P)
            H1 = C+P
            for i in range(1, num_utter):
                # print(i,num_utter)
                if self.args.attn_type == 'rgcn':
                    # Message M aggregates the already-updated states H1 of
                    # earlier nodes, masked by adj and the speaker mask.
                    _, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i], s_mask[:,i,:i])
                    # _, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i], s_mask_onehot[:,i,:i,:])
                else:
                    if not self.rel_attn:
                        _, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i])
                    else:
                        _, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i], s_mask[:, i, :i])

                C = self.grus_c[l](H[l][:,i,:], M).unsqueeze(1)
                P = self.grus_p[l](M, H[l][:,i,:]).unsqueeze(1)
                # P = M.unsqueeze(1)
                #H_temp = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
                #H_temp = F.relu(C+P)
                H_temp = C+P
                H1 = torch.cat((H1 , H_temp), dim = 1)
            H.append(H1)
        H.append(features)

        H = torch.cat(H, dim = 2)

        # Optional nodal attention over the concatenated representation.
        H = self.attentive_node_features(H,lengths,self.nodal_att_type)

        logits = self.out_mlp(H)

        return logits
245
+
246
+
247
+ #仅仅使用最后一层的short和long,concat;只用过去特征
248
+ #Only use the final layer's short and long features, concatenated; use only past features.
249
class DAGERC_new_1(nn.Module):
    """Two-channel DAG ERC: only the LAST layer's short- and long-distance
    features are concatenated; only past (left-context) features are used.

    A "short" channel runs on adjacency adj_1 and a "long" channel on adj_2,
    each with its own GAT + GRU stack.  A cross-attention step mixes the two
    channels and a DiffLoss term pushes them apart (returned scaled by beta).
    """

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args
        # gcn layer

        self.dropout = nn.Dropout(args.dropout)

        self.gnn_layers = args.gnn_layers

        if not args.no_rel_attn:
            self.rel_attn = True
        else:
            self.rel_attn = False

        if self.args.attn_type == 'linear':
            gats = []
            for _ in range(args.gnn_layers):
                gats += [GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)
        elif self.args.attn_type == 'dotprod':
            gats = []
            for _ in range(args.gnn_layers):
                gats += [GatDot(args.hidden_dim) if args.no_rel_attn else GatDot_rel(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)
        elif self.args.attn_type == 'rgcn':
            # Short-distance channel
            gats_short = []
            gats_long = []
            for _ in range(args.gnn_layers):
                gats_short += [GAT_dialoggcn_v1(args.hidden_dim)]
            for _ in range(args.gnn_layers):
                gats_long += [GAT_dialoggcn_v1(args.hidden_dim)]
            self.gather_short = nn.ModuleList(gats_short)
            self.gather_long = nn.ModuleList(gats_long)

        # Short-distance GRUs
        grus_c_short = []
        for _ in range(args.gnn_layers):
            grus_c_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_c_short = nn.ModuleList(grus_c_short)

        # Long-distance GRUs
        grus_c_long = []
        for _ in range(args.gnn_layers):
            grus_c_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_c_long = nn.ModuleList(grus_c_long)

        grus_p_short = []
        for _ in range(args.gnn_layers):
            grus_p_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_p_short = nn.ModuleList(grus_p_short)

        grus_p_long = []
        for _ in range(args.gnn_layers):
            grus_p_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_p_long = nn.ModuleList(grus_p_long)

        # Short-distance fully connected layers (unused by the active C+P path).
        fcs_short = []
        for _ in range(args.gnn_layers):
            fcs_short += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
        self.fcs_short = nn.ModuleList(fcs_short)

        # Long-distance fully connected layers (unused by the active C+P path).
        fcs_long = []
        for _ in range(args.gnn_layers):
            fcs_long += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
        self.fcs_long = nn.ModuleList(fcs_long)


        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)

        self.nodal_att_type = args.nodal_att_type

        # Classifier input: short (D) + long (D) from the last layer + raw features.
        in_dim = ((args.hidden_dim*2)+ args.emb_dim)
        # print(in_dim)
        # output mlp layers
        layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
        layers += [self.dropout]
        layers += [nn.Linear(args.hidden_dim, num_class)]

        self.out_mlp = nn.Sequential(*layers)

        self.attentive_node_features = attentive_node_features(in_dim)

        # Bilinear cross-attention weights between the two channels.
        self.affine1 = nn.Parameter(torch.empty(size=((args.hidden_dim) , (args.hidden_dim) )))
        nn.init.xavier_uniform_(self.affine1.data, gain=1.414)
        self.affine2 = nn.Parameter(torch.empty(size=((args.hidden_dim) , (args.hidden_dim) )))
        nn.init.xavier_uniform_(self.affine2.data, gain=1.414)

        self.diff_loss = DiffLoss(args)
        self.beta = args.diffloss

    def forward(self, features, adj_1, adj_2 ,s_mask, s_mask_onehot, lengths):
        """Run both channels, fuse with cross-attention, and classify.

        :param features: (B, N, D) utterance features
        :param adj_1: short-distance adjacency (B, N, N)
        :param adj_2: long-distance adjacency (B, N, N)
        :param s_mask: speaker mask (B, N, N)
        :param s_mask_onehot: one-hot speaker mask (B, N, N, 2) — unused here
        :param lengths: per-dialog utterance counts for nodal attention
        :return: (logits, beta * diff_loss)
        """
        # Check whether adj_1 and adj_2 are identical (debug leftover; unused).
        are_equal = all(torch.equal(h1, h2) for h1, h2 in zip(adj_1, adj_2))
        # print("are adj1 and adj2 fully equal:", are_equal)
        # print('adj1',adj_1)
        # print('----------------------------------------------------')

        # print('adj2',adj_2)
        # print('----------------------------------------------------')

        num_utter = features.size()[1]

        H0 = F.relu(self.fc1(features))
        #print('H0', H0.size())
        # H0 = self.dropout(H0)
        H = [H0]
        H_combined_short_list = []
        # Process the short-distance channel.
        for l in range(self.args.gnn_layers):
            # Node 0: GRU update with a zero message.
            C = self.grus_c_short[l](H[l][:,0,:]).unsqueeze(1)
            M = torch.zeros_like(C).squeeze(1)
            # P = M.unsqueeze(1)
            # P combines the (zero) message with node 0's features, shape (B, 1, D).
            P = self.grus_p_short[l](M, H[l][:,0,:]).unsqueeze(1)
            #H1 = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
            #H1 = F.relu(C+P)
            H1 = C+P  # fused state for node 0 of this layer
            for i in range(1, num_utter):
                # print(i,num_utter)
                if self.args.attn_type == 'rgcn':
                    # Aggregate message M for node i from earlier nodes' updated
                    # states H1, masked by adj_1 and the speaker mask.
                    _, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i], s_mask[:,i,:i])
                    # _, M = self.gather[l](H[l][:,i,:], H1, H1, adj[:,i,:i], s_mask_onehot[:,i,:i,:])
                else:
                    if not self.rel_attn:
                        _, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i])
                    else:
                        _, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i], s_mask[:, i, :i])

                # C: message-conditioned node update via GRU.
                C = self.grus_c_short[l](H[l][:,i,:], M).unsqueeze(1)
                # P: node-conditioned message update via GRU.
                P = self.grus_p_short[l](M, H[l][:,i,:]).unsqueeze(1)
                # P = M.unsqueeze(1)
                #H_temp = F.relu(self.fcs[l](torch.cat((C,P) , dim = 2)))
                #H_temp = F.relu(C+P)
                H_temp = C+P
                H1 = torch.cat((H1 , H_temp), dim = 1)  # append node i's state
                # print('H1', H1.size())
                #print('----------------------------------------------------')
            H.append(H1)
            H_combined_short_list.append(H[l+1])
        '''
        The following processes the long-distance channel.
        '''
        H_long = [H0]  # initialize H_long
        H_combined_long_list = []  # per-layer long-distance outputs

        # Process the long-distance channel (same recurrence, adj_2 edges).
        for l in range(self.args.gnn_layers):
            C_long = self.grus_c_long[l](H_long[l][:,0,:]).unsqueeze(1)  # node 0 via GRU
            M_long = torch.zeros_like(C_long).squeeze(1)  # zero initial message
            P_long = self.grus_p_long[l](M_long, H_long[l][:,0,:]).unsqueeze(1)

            H1_long = C_long + P_long  # fused state for node 0
            for i in range(1, num_utter):
                # Aggregate according to the configured attention type.
                if self.args.attn_type == 'rgcn':
                    _, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i], s_mask[:,i,:i])
                else:
                    if not self.rel_attn:
                        _, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i])
                    else:
                        _, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i], s_mask[:,i,:i])

                # GRU updates for node i in the long channel.
                C_long = self.grus_c_long[l](H_long[l][:,i,:], M_long).unsqueeze(1)
                P_long = self.grus_p_long[l](M_long, H_long[l][:,i,:]).unsqueeze(1)

                H_temp_long = C_long + P_long
                H1_long = torch.cat((H1_long, H_temp_long), dim=1)
            H_long.append(H1_long)
            H_combined_long_list.append(H_long[l+1])

        '''
        Both short- and long-distance channel features have been extracted.
        '''
        # print('H_combined_short_list',H_combined_short_list)
        # print('H_combined_long_list',H_combined_long_list)
        # are_equal = all(torch.equal(h1, h2) for h1, h2 in zip(H_combined_short_list, H_combined_long_list))
        # print("are short and long lists fully equal:", are_equal)
        # for idx, tensor in enumerate(H_combined_short_list):
        #     print(f"H_combined_short_list[{idx}] shape: {tensor.shape}")
        H_final = []
        # print("H2 shape:", H2.shape)
        # Difference-regularization loss between the two channels.
        diff_loss = 0
        for l in range(self.args.gnn_layers):
            # print('layer:', l)
            HShort_prime = H_combined_short_list[l]
            HLong_prime = H_combined_long_list[l]
            # print("HShort_prime:", HShort_prime)
            # print("HLong_prime:", HLong_prime)
            # print("HShort_prime shape:", HShort_prime.shape)
            # print("HLong_prime shape:", HLong_prime.shape)
            diff_loss = self.diff_loss(HShort_prime, HLong_prime) + diff_loss
            # print("diff_loss:", diff_loss)
            # print(diff_loss.item())
            # Mutual cross-attention between the channels.
            A1 = F.softmax(torch.bmm(torch.matmul(HShort_prime, self.affine1), torch.transpose(HLong_prime, 1, 2)), dim=2)
            A2 = F.softmax(torch.bmm(torch.matmul(HLong_prime, self.affine2), torch.transpose(HShort_prime, 1, 2)), dim=2)

            HShort_prime_new = torch.bmm(A1, HLong_prime)  # updated short features
            HLong_prime_new = torch.bmm(A2, HShort_prime)  # updated long features

            HShort_prime_out = self.dropout(HShort_prime_new) if l < self.args.gnn_layers - 1 else HShort_prime_new
            HLong_prime_out = self.dropout(HLong_prime_new) if l <self.args.gnn_layers - 1 else HLong_prime_new

            H_final.append(HShort_prime_out)
            H_final.append(HLong_prime_out)
        H_final.append(features)

        # Only the LAST layer's short/long outputs (plus raw features) are used.
        H_final = torch.cat([H_final[-3],H_final[-2],H_final[-1]], dim = 2)
        # print("H shape:", H.shape)
        # print("H:", H.shape)
        # print("H_final shape after cat:", H_final.shape)
        H_final = self.attentive_node_features(H_final,lengths,self.nodal_att_type)
        # print("H_final shape after attentive_node_features:", H_final.shape)
        logits = self.out_mlp(H_final)
        # print(diff_loss)
        return logits, self.beta * diff_loss
478
+
479
+
480
+ #仅仅使用最后一层的short和long,concat;使用了过去和未来双特征
481
+ #Only the final-layer short and long features are used and concatenated; both past and future features are utilized.
482
class DAGERC_new_2(nn.Module):
    """Two-channel, bidirectional DAG ERC.

    Like DAGERC_new_1 but each channel (short = adj_1, long = adj_2) is run both
    forward and on the time-reversed dialog; forward and reversed states are
    concatenated per layer (hidden_dim * 2 per channel).  Only the LAST layer's
    short and long outputs are concatenated with the raw features for the
    classifier.  A cross-attention step mixes the two channels and a DiffLoss
    term pushes them apart (returned scaled by beta).
    """

    def __init__(self, args, num_class):
        super().__init__()
        self.args = args
        # gcn layer

        self.dropout = nn.Dropout(args.dropout)

        self.gnn_layers = args.gnn_layers

        if not args.no_rel_attn:
            self.rel_attn = True
        else:
            self.rel_attn = False

        if self.args.attn_type == 'linear':
            gats = []
            for _ in range(args.gnn_layers):
                gats += [GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)
        elif self.args.attn_type == 'dotprod':
            gats = []
            for _ in range(args.gnn_layers):
                gats += [GatDot(args.hidden_dim) if args.no_rel_attn else GatDot_rel(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)
        elif self.args.attn_type == 'rgcn':
            # One GAT stack per channel.
            gats_short = []
            gats_long = []
            for _ in range(args.gnn_layers):
                gats_short += [GAT_dialoggcn_v1(args.hidden_dim)]
            for _ in range(args.gnn_layers):
                gats_long += [GAT_dialoggcn_v1(args.hidden_dim)]
            self.gather_short = nn.ModuleList(gats_short)
            self.gather_long = nn.ModuleList(gats_long)

        # Short-distance GRUs (shared between forward and reversed passes).
        grus_c_short = []
        for _ in range(args.gnn_layers):
            grus_c_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_c_short = nn.ModuleList(grus_c_short)

        # Long-distance GRUs.
        grus_c_long = []
        for _ in range(args.gnn_layers):
            grus_c_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_c_long = nn.ModuleList(grus_c_long)

        grus_p_short = []
        for _ in range(args.gnn_layers):
            grus_p_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_p_short = nn.ModuleList(grus_p_short)

        grus_p_long = []
        for _ in range(args.gnn_layers):
            grus_p_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_p_long = nn.ModuleList(grus_p_long)

        # Per-channel fully connected layers (unused by the active C+P path).
        fcs_short = []
        for _ in range(args.gnn_layers):
            fcs_short += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
        self.fcs_short = nn.ModuleList(fcs_short)

        fcs_long = []
        for _ in range(args.gnn_layers):
            fcs_long += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
        self.fcs_long = nn.ModuleList(fcs_long)


        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)

        self.nodal_att_type = args.nodal_att_type

        # Classifier input: 2*D short + 2*D long (fwd+rev each) + raw features.
        in_dim = ((args.hidden_dim*2)*2 + args.emb_dim)
        # output mlp layers
        layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
        layers += [self.dropout]
        layers += [nn.Linear(args.hidden_dim, num_class)]

        self.out_mlp = nn.Sequential(*layers)

        self.attentive_node_features = attentive_node_features(in_dim)

        # Bilinear cross-attention weights; 2*D because fwd+rev are concatenated.
        self.affine1 = nn.Parameter(torch.empty(size=((args.hidden_dim*2) , (args.hidden_dim*2) )))
        nn.init.xavier_uniform_(self.affine1.data, gain=1.414)
        self.affine2 = nn.Parameter(torch.empty(size=((args.hidden_dim*2) , (args.hidden_dim*2) )))
        nn.init.xavier_uniform_(self.affine2.data, gain=1.414)

        self.diff_loss = DiffLoss(args)
        self.beta = args.diffloss

    def forward(self, features, adj_1, adj_2 ,s_mask, s_mask_onehot, lengths):
        """Run both channels in both directions, fuse, and classify.

        :param features: (B, N, D) utterance features
        :param adj_1: short-distance adjacency (B, N, N)
        :param adj_2: long-distance adjacency (B, N, N)
        :param s_mask: speaker mask (B, N, N)
        :param s_mask_onehot: one-hot speaker mask (B, N, N, 2) — unused here
        :param lengths: per-dialog utterance counts for nodal attention
        :return: (logits, beta * diff_loss)
        """
        # FIX: removed leftover debug work from the original —
        #  * an `are_equal` scan comparing adj_1 with adj_2 whose result was
        #    never used (wasted a full tensor comparison every forward), and
        #  * four print() calls inside the fusion loop below that spammed the
        #    log and forced a host/device sync on every step.
        num_utter = features.size()[1]

        H0 = F.relu(self.fc1(features))
        H = [H0]
        H_combined_short_list = []
        # ---- Short-distance channel, forward direction ----
        for l in range(self.args.gnn_layers):
            # Node 0: GRU update with a zero message.
            C = self.grus_c_short[l](H[l][:,0,:]).unsqueeze(1)
            M = torch.zeros_like(C).squeeze(1)
            P = self.grus_p_short[l](M, H[l][:,0,:]).unsqueeze(1)
            H1 = C+P  # fused state for node 0 of this layer
            for i in range(1, num_utter):
                if self.args.attn_type == 'rgcn':
                    # Message M aggregates earlier nodes' updated states H1,
                    # masked by adj_1 and the speaker mask.
                    _, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i], s_mask[:,i,:i])
                else:
                    if not self.rel_attn:
                        _, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i])
                    else:
                        _, M = self.gather_short[l](H[l][:,i,:], H1, H1, adj_1[:,i,:i], s_mask[:, i, :i])

                # C: message-conditioned node update; P: node-conditioned
                # message update; their sum is the new node state.
                C = self.grus_c_short[l](H[l][:,i,:], M).unsqueeze(1)
                P = self.grus_p_short[l](M, H[l][:,i,:]).unsqueeze(1)
                H_temp = C+P
                H1 = torch.cat((H1 , H_temp), dim = 1)
            H.append(H1)

        # ---- Short-distance channel, reversed direction ----
        # Reverse the utterance order so the "past" of the reversed pass is the
        # future of the original dialog.
        features_reversed = torch.flip(features, dims=[1])
        adj_reversed = torch.flip(adj_1, dims=[1, 2])
        s_mask_reversed = torch.flip(s_mask, dims=[1, 2])

        H0_reversed = F.relu(self.fc1(features_reversed))
        H_reversed = [H0_reversed]

        for l in range(self.args.gnn_layers):
            C = self.grus_c_short[l](H_reversed[l][:, 0, :]).unsqueeze(1)
            M = torch.zeros_like(C).squeeze(1)
            P = self.grus_p_short[l](M, H_reversed[l][:, 0, :]).unsqueeze(1)
            H1_reversed = C + P

            for i in range(1, num_utter):
                if self.args.attn_type == 'rgcn':
                    _, M = self.gather_short[l](H_reversed[l][:, i, :], H1_reversed, H1_reversed, adj_reversed[:, i, :i], s_mask_reversed[:, i, :i])
                else:
                    if not self.rel_attn:
                        _, M = self.gather_short[l](H_reversed[l][:, i, :], H1_reversed, H1_reversed, adj_reversed[:, i, :i])
                    else:
                        _, M = self.gather_short[l](H_reversed[l][:, i, :], H1_reversed, H1_reversed, adj_reversed[:, i, :i], s_mask_reversed[:, i, :i])

                C = self.grus_c_short[l](H_reversed[l][:, i, :], M).unsqueeze(1)
                P = self.grus_p_short[l](M, H_reversed[l][:, i, :]).unsqueeze(1)
                H_temp_reversed = C + P
                H1_reversed = torch.cat((H1_reversed, H_temp_reversed), dim=1)
            H_reversed.append(H1_reversed)
            # Concatenate forward and reversed states along the feature dim.
            # NOTE(review): H_reversed[l+1] is still in reversed utterance
            # order — node i of the forward pass is concatenated with node
            # N-1-i of the reversed pass.  Confirm whether a flip back along
            # dim 1 was intended before this cat.
            H_combined = torch.cat((H[l+1], H_reversed[l+1]), dim=2)
            H_combined_short_list.append(H_combined)

        # ---- Long-distance channel, forward direction ----
        H_long = [H0]
        H_combined_long_list = []

        for l in range(self.args.gnn_layers):
            C_long = self.grus_c_long[l](H_long[l][:,0,:]).unsqueeze(1)
            M_long = torch.zeros_like(C_long).squeeze(1)
            P_long = self.grus_p_long[l](M_long, H_long[l][:,0,:]).unsqueeze(1)

            H1_long = C_long + P_long
            for i in range(1, num_utter):
                if self.args.attn_type == 'rgcn':
                    _, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i], s_mask[:,i,:i])
                else:
                    if not self.rel_attn:
                        _, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i])
                    else:
                        _, M_long = self.gather_long[l](H_long[l][:,i,:], H1_long, H1_long, adj_2[:,i,:i], s_mask[:,i,:i])

                C_long = self.grus_c_long[l](H_long[l][:,i,:], M_long).unsqueeze(1)
                P_long = self.grus_p_long[l](M_long, H_long[l][:,i,:]).unsqueeze(1)

                H_temp_long = C_long + P_long
                H1_long = torch.cat((H1_long, H_temp_long), dim=1)
            H_long.append(H1_long)

        # ---- Long-distance channel, reversed direction ----
        features_reversed_long = torch.flip(features, dims=[1])
        adj_reversed_long = torch.flip(adj_2, dims=[1, 2])
        s_mask_reversed_long = torch.flip(s_mask, dims=[1, 2])

        H0_reversed_long = F.relu(self.fc1(features_reversed_long))
        H_reversed_long = [H0_reversed_long]

        for l in range(self.args.gnn_layers):
            C_long = self.grus_c_long[l](H_reversed_long[l][:, 0, :]).unsqueeze(1)
            M_long = torch.zeros_like(C_long).squeeze(1)
            P_long = self.grus_p_long[l](M_long, H_reversed_long[l][:, 0, :]).unsqueeze(1)
            H1_reversed_long = C_long + P_long

            for i in range(1, num_utter):
                if self.args.attn_type == 'rgcn':
                    _, M_long = self.gather_long[l](H_reversed_long[l][:, i, :], H1_reversed_long, H1_reversed_long, adj_reversed_long[:, i, :i], s_mask_reversed_long[:, i, :i])
                else:
                    if not self.rel_attn:
                        _, M_long = self.gather_long[l](H_reversed_long[l][:, i, :], H1_reversed_long, H1_reversed_long, adj_reversed_long[:, i, :i])
                    else:
                        _, M_long = self.gather_long[l](H_reversed_long[l][:, i, :], H1_reversed_long, H1_reversed_long, adj_reversed_long[:, i, :i], s_mask_reversed_long[:, i, :i])

                C_long = self.grus_c_long[l](H_reversed_long[l][:, i, :], M_long).unsqueeze(1)
                P_long = self.grus_p_long[l](M_long, H_reversed_long[l][:, i, :]).unsqueeze(1)
                H_temp_reversed_long = C_long + P_long
                H1_reversed_long = torch.cat((H1_reversed_long, H_temp_reversed_long), dim=1)
            H_reversed_long.append(H1_reversed_long)

            # Concatenate forward and reversed long-distance states.
            H_combined_long = torch.cat((H_long[l+1], H_reversed_long[l+1]), dim=2)
            H_combined_long_list.append(H_combined_long)

        # ---- Channel fusion: DiffLoss + mutual cross-attention per layer ----
        H_final = []
        diff_loss = 0
        for l in range(self.args.gnn_layers):
            HShort_prime = H_combined_short_list[l]
            HLong_prime = H_combined_long_list[l]
            # Accumulate the difference-regularization loss across layers.
            diff_loss = self.diff_loss(HShort_prime, HLong_prime) + diff_loss
            # Mutual cross-attention between the channels.
            A1 = F.softmax(torch.bmm(torch.matmul(HShort_prime, self.affine1), torch.transpose(HLong_prime, 1, 2)), dim=2)
            A2 = F.softmax(torch.bmm(torch.matmul(HLong_prime, self.affine2), torch.transpose(HShort_prime, 1, 2)), dim=2)

            HShort_prime_new = torch.bmm(A1, HLong_prime)  # updated short features
            HLong_prime_new = torch.bmm(A2, HShort_prime)  # updated long features

            # Dropout on every layer except the last.
            HShort_prime_out = self.dropout(HShort_prime_new) if l < self.args.gnn_layers - 1 else HShort_prime_new
            HLong_prime_out = self.dropout(HLong_prime_new) if l <self.args.gnn_layers - 1 else HLong_prime_new

            H_final.append(HShort_prime_out)
            H_final.append(HLong_prime_out)
        H_final.append(features)

        # Only the LAST layer's short/long outputs (plus raw features) are used.
        H_final = torch.cat([H_final[-3],H_final[-2],H_final[-1]], dim = 2)
        H_final = self.attentive_node_features(H_final,lengths,self.nodal_att_type)
        logits = self.out_mlp(H_final)
        return logits, self.beta * diff_loss
775
+
776
+
777
+ #使用所有层的short和long,使用sum加每一层,不使用双特征融合技术
778
+ #All-layer short and long features are used, with a sum over each layer; dual-feature fusion is not applied.
779
class DAGERC_new_3(nn.Module):
    """Dual-channel DAG-ERC variant: short- and long-distance graph channels.

    Uses all GNN layers of both channels, fuses them by element-wise SUM per
    layer (no dual-feature cross fusion), concatenates the per-layer states
    with the raw utterance features, applies optional nodal attention and a
    classification MLP.
    """

    def __init__(self, args, num_class):
        """
        :param args: namespace with hidden_dim, emb_dim, gnn_layers, mlp_layers,
            dropout, attn_type, no_rel_attn, nodal_att_type (and more).
        :param num_class: number of emotion classes to predict.
        """
        super().__init__()
        self.args = args
        # gcn layer

        self.dropout = nn.Dropout(args.dropout)

        self.gnn_layers = args.gnn_layers

        # rel_attn = use relation-aware attention unless explicitly disabled.
        if not args.no_rel_attn:
            self.rel_attn = True
        else:
            self.rel_attn = False

        if self.args.attn_type == 'linear':
            gats = []
            for _ in range(args.gnn_layers):
                gats += [GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)
        elif self.args.attn_type == 'dotprod':
            gats = []
            for _ in range(args.gnn_layers):
                gats += [GatDot(args.hidden_dim) if args.no_rel_attn else GatDot_rel(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)
        elif self.args.attn_type == 'rgcn':
            # Separate gatherers for the short-distance and long-distance graphs.
            gats_short = []
            gats_long = []
            for _ in range(args.gnn_layers):
                gats_short += [GAT_dialoggcn_v1(args.hidden_dim)]
            for _ in range(args.gnn_layers):
                gats_long += [GAT_dialoggcn_v1(args.hidden_dim)]
            self.gather_short = nn.ModuleList(gats_short)
            self.gather_long = nn.ModuleList(gats_long)

        # Short-distance GRU cells (context update per layer).
        grus_c_short = []
        for _ in range(args.gnn_layers):
            grus_c_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_c_short = nn.ModuleList(grus_c_short)

        # Long-distance GRU cells (context update per layer).
        grus_c_long = []
        for _ in range(args.gnn_layers):
            grus_c_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_c_long = nn.ModuleList(grus_c_long)

        # Short-distance GRU cells (parent/state update per layer).
        grus_p_short = []
        for _ in range(args.gnn_layers):
            grus_p_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_p_short = nn.ModuleList(grus_p_short)

        # Long-distance GRU cells (parent/state update per layer).
        grus_p_long = []
        for _ in range(args.gnn_layers):
            grus_p_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_p_long = nn.ModuleList(grus_p_long)

        # Short-distance fully connected layers.
        # NOTE(review): fcs_short/fcs_long are built but never used in forward()
        # (the C+P merge replaces them) — confirm they are intentional leftovers.
        fcs_short = []
        for _ in range(args.gnn_layers):
            fcs_short += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
        self.fcs_short = nn.ModuleList(fcs_short)

        # Long-distance fully connected layers.
        fcs_long = []
        for _ in range(args.gnn_layers):
            fcs_long += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
        self.fcs_long = nn.ModuleList(fcs_long)

        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)

        self.nodal_att_type = args.nodal_att_type

        # Concatenation of (gnn_layers + 1) hidden states plus raw features.
        in_dim = (args.hidden_dim * (args.gnn_layers + 1)) + args.emb_dim

        # output mlp layers
        layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
        layers += [self.dropout]
        layers += [nn.Linear(args.hidden_dim, num_class)]

        self.out_mlp = nn.Sequential(*layers)

        self.attentive_node_features = attentive_node_features(in_dim)

    def forward(self, features, adj_1, adj_2, s_mask, s_mask_onehot, lengths):
        """
        :param features: (B, N, emb_dim) utterance features.
        :param adj_1: (B, N, N) short-distance adjacency.
        :param adj_2: (B, N, N) long-distance adjacency.
        :param s_mask: (B, N, N) same-speaker indicator.
        :param s_mask_onehot: one-hot speaker mask (unused here).
        :param lengths: (B,) true dialogue lengths.
        :return: (B, N, num_class) logits.
        """
        # Debug leftover: checks whether the two adjacencies coincide (unused).
        are_equal = all(torch.equal(h1, h2) for h1, h2 in zip(adj_1, adj_2))

        num_utter = features.size()[1]

        H0 = F.relu(self.fc1(features))
        H = [H0]
        # Leftover accumulator, never filled in this variant.
        H_combined_short_list = []
        # ---- Short-distance channel ----
        for l in range(self.args.gnn_layers):
            # First utterance: GRU update with a zero aggregation vector.
            C = self.grus_c_short[l](H[l][:, 0, :]).unsqueeze(1)
            M = torch.zeros_like(C).squeeze(1)
            P = self.grus_p_short[l](M, H[l][:, 0, :]).unsqueeze(1)
            # New node state is the sum of the two GRU outputs.
            H1 = C + P
            for i in range(1, num_utter):
                # Aggregate information from previous utterances via attention.
                if self.args.attn_type == 'rgcn':
                    _, M = self.gather_short[l](H[l][:, i, :], H1, H1, adj_1[:, i, :i], s_mask[:, i, :i])
                else:
                    if not self.rel_attn:
                        _, M = self.gather_short[l](H[l][:, i, :], H1, H1, adj_1[:, i, :i])
                    else:
                        _, M = self.gather_short[l](H[l][:, i, :], H1, H1, adj_1[:, i, :i], s_mask[:, i, :i])

                # GRU update of the current node with the aggregated context M.
                C = self.grus_c_short[l](H[l][:, i, :], M).unsqueeze(1)
                # Second GRU: update the aggregate M with the node features.
                P = self.grus_p_short[l](M, H[l][:, i, :]).unsqueeze(1)
                H_temp = C + P
                # Append the new node state along the utterance axis.
                H1 = torch.cat((H1, H_temp), dim=1)
            H.append(H1)
        # ---- Long-distance channel (same recurrence, adj_2 graph) ----
        H_long = [H0]
        H_combined_long_list = []  # leftover, never filled in this variant

        for l in range(self.args.gnn_layers):
            C_long = self.grus_c_long[l](H_long[l][:, 0, :]).unsqueeze(1)
            M_long = torch.zeros_like(C_long).squeeze(1)
            P_long = self.grus_p_long[l](M_long, H_long[l][:, 0, :]).unsqueeze(1)

            H1_long = C_long + P_long
            for i in range(1, num_utter):
                if self.args.attn_type == 'rgcn':
                    _, M_long = self.gather_long[l](H_long[l][:, i, :], H1_long, H1_long, adj_2[:, i, :i], s_mask[:, i, :i])
                else:
                    if not self.rel_attn:
                        _, M_long = self.gather_long[l](H_long[l][:, i, :], H1_long, H1_long, adj_2[:, i, :i])
                    else:
                        _, M_long = self.gather_long[l](H_long[l][:, i, :], H1_long, H1_long, adj_2[:, i, :i], s_mask[:, i, :i])

                C_long = self.grus_c_long[l](H_long[l][:, i, :], M_long).unsqueeze(1)
                P_long = self.grus_p_long[l](M_long, H_long[l][:, i, :]).unsqueeze(1)

                H_temp_long = C_long + P_long
                H1_long = torch.cat((H1_long, H_temp_long), dim=1)
            H_long.append(H1_long)

        # Concatenate all layers of each channel, fuse channels by element-wise sum,
        # then append the raw utterance features.
        H_combined = torch.cat(H, dim=2)
        H_long_combined = torch.cat(H_long, dim=2)
        sum_features = H_combined + H_long_combined
        H_combined_final = torch.cat((sum_features, features), dim=2)

        H_final = self.attentive_node_features(H_combined_final, lengths, self.nodal_att_type)
        logits = self.out_mlp(H_final)
        return logits
965
+
966
+ #使用过去的所有层的short和long,每一层都concat,使用特征融合技术。
967
+ #All past-layer short and long features are used; features from each layer are concatenated, and feature fusion techniques are applied.
968
class DAGERC_new_4(nn.Module):
    """Dual-channel DAG-ERC variant with per-layer cross-attention fusion.

    All past-layer short and long features are used; features from each layer
    are concatenated, mutual cross-attention fuses the two channels, and a
    DiffLoss regularizer pushes the channels apart. Returns (logits, reg_loss).
    """

    def __init__(self, args, num_class):
        """
        :param args: namespace with hidden_dim, emb_dim, gnn_layers, mlp_layers,
            dropout, attn_type, no_rel_attn, nodal_att_type, diffloss (and more).
        :param num_class: number of emotion classes to predict.
        """
        super().__init__()
        self.args = args
        # gcn layer

        self.dropout = nn.Dropout(args.dropout)

        self.gnn_layers = args.gnn_layers

        # rel_attn = use relation-aware attention unless explicitly disabled.
        if not args.no_rel_attn:
            self.rel_attn = True
        else:
            self.rel_attn = False

        if self.args.attn_type == 'linear':
            gats = []
            for _ in range(args.gnn_layers):
                gats += [GatLinear(args.hidden_dim) if args.no_rel_attn else GatLinear_rel(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)
        elif self.args.attn_type == 'dotprod':
            gats = []
            for _ in range(args.gnn_layers):
                gats += [GatDot(args.hidden_dim) if args.no_rel_attn else GatDot_rel(args.hidden_dim)]
            self.gather = nn.ModuleList(gats)
        elif self.args.attn_type == 'rgcn':
            # Separate gatherers for the short- and long-distance graphs.
            gats_short = []
            gats_long = []
            for _ in range(args.gnn_layers):
                gats_short += [GAT_dialoggcn_v1(args.hidden_dim)]
            for _ in range(args.gnn_layers):
                gats_long += [GAT_dialoggcn_v1(args.hidden_dim)]
            self.gather_short = nn.ModuleList(gats_short)
            self.gather_long = nn.ModuleList(gats_long)

        # Short-distance GRU cells (context update per layer).
        grus_c_short = []
        for _ in range(args.gnn_layers):
            grus_c_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_c_short = nn.ModuleList(grus_c_short)

        # Long-distance GRU cells (context update per layer).
        grus_c_long = []
        for _ in range(args.gnn_layers):
            grus_c_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_c_long = nn.ModuleList(grus_c_long)

        # Short-distance GRU cells (parent/state update per layer).
        grus_p_short = []
        for _ in range(args.gnn_layers):
            grus_p_short += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_p_short = nn.ModuleList(grus_p_short)

        # Long-distance GRU cells (parent/state update per layer).
        grus_p_long = []
        for _ in range(args.gnn_layers):
            grus_p_long += [nn.GRUCell(args.hidden_dim, args.hidden_dim)]
        self.grus_p_long = nn.ModuleList(grus_p_long)

        # Fully connected layers for short-range features.
        # NOTE(review): fcs_short/fcs_long are built but unused in forward() —
        # confirm they are intentional leftovers.
        fcs_short = []
        for _ in range(args.gnn_layers):
            fcs_short += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
        self.fcs_short = nn.ModuleList(fcs_short)

        # Fully connected layers for long-range features.
        fcs_long = []
        for _ in range(args.gnn_layers):
            fcs_long += [nn.Linear(args.hidden_dim * 2, args.hidden_dim)]
        self.fcs_long = nn.ModuleList(fcs_long)

        self.fc1 = nn.Linear(args.emb_dim, args.hidden_dim)

        self.nodal_att_type = args.nodal_att_type

        # (short||long) = 2*hidden per layer, (gnn_layers + 1) layers, plus raw features.
        in_dim = (((args.hidden_dim * 2)) * (args.gnn_layers + 1) + args.emb_dim)

        # output mlp layers
        layers = [nn.Linear(in_dim, args.hidden_dim), nn.ReLU()]
        for _ in range(args.mlp_layers - 1):
            layers += [nn.Linear(args.hidden_dim, args.hidden_dim), nn.ReLU()]
        layers += [self.dropout]
        layers += [nn.Linear(args.hidden_dim, num_class)]

        self.out_mlp = nn.Sequential(*layers)

        self.attentive_node_features = attentive_node_features(in_dim)

        # Bilinear maps for the mutual cross-attention between channels.
        self.affine1 = nn.Parameter(torch.empty(size=((args.hidden_dim), (args.hidden_dim))))
        nn.init.xavier_uniform_(self.affine1.data, gain=1.414)
        self.affine2 = nn.Parameter(torch.empty(size=((args.hidden_dim), (args.hidden_dim))))
        nn.init.xavier_uniform_(self.affine2.data, gain=1.414)

        self.diff_loss = DiffLoss(args)
        # Weight of the difference regularizer in the returned loss term.
        self.beta = args.diffloss

    def forward(self, features, adj_1, adj_2, s_mask, s_mask_onehot, lengths):
        """
        :param features: (B, N, emb_dim) utterance features.
        :param adj_1: (B, N, N) short-distance adjacency.
        :param adj_2: (B, N, N) long-distance adjacency.
        :param s_mask: (B, N, N) same-speaker indicator.
        :param s_mask_onehot: one-hot speaker mask (unused here).
        :param lengths: (B,) true dialogue lengths.
        :return: ((B, N, num_class) logits, weighted diff-regularization loss).
        """
        # Debug leftover: checks whether the two adjacencies coincide (unused).
        are_equal = all(torch.equal(h1, h2) for h1, h2 in zip(adj_1, adj_2))

        num_utter = features.size()[1]

        H0 = F.relu(self.fc1(features))
        H = [H0]
        H_combined_short_list = []
        # ---- Short-distance channel ----
        for l in range(self.args.gnn_layers):
            # First utterance: GRU update with a zero aggregation vector.
            C = self.grus_c_short[l](H[l][:, 0, :]).unsqueeze(1)
            M = torch.zeros_like(C).squeeze(1)
            P = self.grus_p_short[l](M, H[l][:, 0, :]).unsqueeze(1)
            H1 = C + P
            for i in range(1, num_utter):
                # Aggregate information from previous utterances via attention.
                if self.args.attn_type == 'rgcn':
                    _, M = self.gather_short[l](H[l][:, i, :], H1, H1, adj_1[:, i, :i], s_mask[:, i, :i])
                else:
                    if not self.rel_attn:
                        _, M = self.gather_short[l](H[l][:, i, :], H1, H1, adj_1[:, i, :i])
                    else:
                        _, M = self.gather_short[l](H[l][:, i, :], H1, H1, adj_1[:, i, :i], s_mask[:, i, :i])

                C = self.grus_c_short[l](H[l][:, i, :], M).unsqueeze(1)

                P = self.grus_p_short[l](M, H[l][:, i, :]).unsqueeze(1)
                # Sum the two GRU outputs to form the new node state.
                H_temp = C + P
                H1 = torch.cat((H1, H_temp), dim=1)
            H.append(H1)
            # Keep this layer's full (B, N, D) state for fusion below.
            H_combined_short_list.append(H[l + 1])

        # ---- Long-distance channel (same recurrence, adj_2 graph) ----
        H_long = [H0]
        H_combined_long_list = []

        for l in range(self.args.gnn_layers):
            C_long = self.grus_c_long[l](H_long[l][:, 0, :]).unsqueeze(1)
            M_long = torch.zeros_like(C_long).squeeze(1)
            P_long = self.grus_p_long[l](M_long, H_long[l][:, 0, :]).unsqueeze(1)

            H1_long = C_long + P_long
            for i in range(1, num_utter):
                if self.args.attn_type == 'rgcn':
                    _, M_long = self.gather_long[l](H_long[l][:, i, :], H1_long, H1_long, adj_2[:, i, :i], s_mask[:, i, :i])
                else:
                    if not self.rel_attn:
                        _, M_long = self.gather_long[l](H_long[l][:, i, :], H1_long, H1_long, adj_2[:, i, :i])
                    else:
                        _, M_long = self.gather_long[l](H_long[l][:, i, :], H1_long, H1_long, adj_2[:, i, :i], s_mask[:, i, :i])

                C_long = self.grus_c_long[l](H_long[l][:, i, :], M_long).unsqueeze(1)
                P_long = self.grus_p_long[l](M_long, H_long[l][:, i, :]).unsqueeze(1)

                H_temp_long = C_long + P_long
                H1_long = torch.cat((H1_long, H_temp_long), dim=1)
            H_long.append(H1_long)
            H_combined_long_list.append(H_long[l + 1])

        # ---- Both channel feature sets extracted; fuse layer by layer ----
        H_final = []
        # Layer 0 slot: duplicated initial state stands in for (short||long).
        H_0_final = torch.cat([H0, H0], dim=2)
        H_final.append(H_0_final)
        # Difference regularization loss accumulated over layers.
        diff_loss = 0
        for l in range(self.args.gnn_layers):
            HShort_prime = H_combined_short_list[l]
            HLong_prime = H_combined_long_list[l]
            diff_loss = self.diff_loss(HShort_prime, HLong_prime) + diff_loss
            # Mutual cross-attention between the two channels.
            A1 = F.softmax(torch.bmm(torch.matmul(HShort_prime, self.affine1), torch.transpose(HLong_prime, 1, 2)), dim=2)
            A2 = F.softmax(torch.bmm(torch.matmul(HLong_prime, self.affine2), torch.transpose(HShort_prime, 1, 2)), dim=2)

            HShort_prime_new = torch.bmm(A1, HLong_prime)  # updated short-range features
            HLong_prime_new = torch.bmm(A2, HShort_prime)  # updated long-range features

            # Dropout on all but the last layer.
            HShort_prime_out = self.dropout(HShort_prime_new) if l < self.args.gnn_layers - 1 else HShort_prime_new
            HLong_prime_out = self.dropout(HLong_prime_new) if l < self.args.gnn_layers - 1 else HLong_prime_new

            H_layer = torch.cat([HShort_prime_out, HLong_prime_out], dim=2)
            H_final.append(H_layer)
        # Concatenate every fused layer, then append the raw utterance features.
        H_final = torch.cat(H_final, dim=2)
        H_final = torch.cat([H_final, features], dim=2)

        H_final = self.attentive_node_features(H_final, lengths, self.nodal_att_type)
        logits = self.out_mlp(H_final)
        return logits, self.beta * diff_loss
model_utils.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+ from torch.nn.utils.rnn import pad_sequence
6
+ import numpy as np, itertools, random, copy, math
7
+
8
+
9
class DiffLoss(nn.Module):
    """Difference regularizer that penalizes SIMILARITY between two feature sets.

    Each input is flattened per sample, zero-centred across the batch and
    L2-normalized; the loss is the reciprocal of the mean L2 distance between
    the normalized representations, so near-identical inputs yield a large
    loss and well-separated inputs a small one.
    """

    def __init__(self, args):
        """:param args: accepted for interface compatibility; unused."""
        super(DiffLoss, self).__init__()

    def forward(self, input1, input2):
        """
        :param input1: (B, N, D) feature tensor.
        :param input2: (B, N, D) feature tensor.
        :return: scalar tensor; ``inf`` when normalized inputs coincide exactly.
        """
        batch_size = input1.size(0)
        input1 = input1.view(batch_size, -1)  # (B, N*D)
        input2 = input2.view(batch_size, -1)  # (B, N*D)

        # Zero-center each feature dimension across the batch.
        input1 = input1 - torch.mean(input1, dim=0, keepdim=True)
        input2 = input2 - torch.mean(input2, dim=0, keepdim=True)

        # L2-normalize each sample; the epsilon guards against zero vectors.
        input1_l2_norm = torch.norm(input1, p=2, dim=1, keepdim=True)  # (B, 1)
        input1_l2 = input1.div(input1_l2_norm.expand_as(input1) + 1e-6)

        input2_l2_norm = torch.norm(input2, p=2, dim=1, keepdim=True)  # (B, 1)
        input2_l2 = input2.div(input2_l2_norm.expand_as(input2) + 1e-6)

        norm_diff = torch.mean(torch.norm(input1_l2 - input2_l2, p=2, dim=1))
        if norm_diff.item() == 0:
            # Degenerate case: the two normalized representations coincide.
            # NOTE(review): this returns a non-differentiable inf tensor;
            # confirm callers never backprop through this branch.
            return torch.tensor(float('inf'), device=input1.device)
        diff_loss = 1.0 / norm_diff

        return diff_loss
48
+
49
+
50
class MaskedNLLLoss(nn.Module):
    """Negative log-likelihood loss that ignores padded positions.

    Wraps ``nn.NLLLoss(reduction='sum')`` and normalizes the summed loss by
    the number of valid (mask == 1) positions, or by the total class weight
    of the valid positions when per-class weights are given.
    """

    def __init__(self, weight=None):
        """:param weight: optional (n_classes,) tensor of per-class weights."""
        super(MaskedNLLLoss, self).__init__()
        self.weight = weight
        self.loss = nn.NLLLoss(weight=weight,
                               reduction='sum')

    def forward(self, pred, target, mask):
        """
        pred -> batch*seq_len, n_classes (log-probabilities)
        target -> batch*seq_len
        mask -> batch, seq_len
        """
        mask_ = mask.view(-1, 1)  # batch*seq_len, 1
        # Idiomatic None check (original used `type(self.weight) == type(None)`).
        if self.weight is None:
            loss = self.loss(pred * mask_, target) / torch.sum(mask)
        else:
            loss = self.loss(pred * mask_, target) \
                / torch.sum(self.weight[target] * mask_.squeeze())
        return loss
71
+
72
+
73
class MaskedMSELoss(nn.Module):
    """Mean-squared-error loss averaged over unmasked positions only."""

    def __init__(self):
        super(MaskedMSELoss, self).__init__()
        # Summed squared error; normalization happens in forward().
        self.loss = nn.MSELoss(reduction='sum')

    def forward(self, pred, target, mask):
        """
        pred -> batch*seq_len
        target -> batch*seq_len
        mask -> batch*seq_len
        """
        masked_pred = pred * mask
        total = self.loss(masked_pred, target)
        return total / torch.sum(mask)
87
+
88
+
89
class UnMaskedWeightedNLLLoss(nn.Module):
    """Negative log-likelihood loss over all positions (no padding mask).

    Wraps ``nn.NLLLoss(reduction='sum')``; when per-class weights are given,
    the summed loss is normalized by the total weight of the targets.
    """

    def __init__(self, weight=None):
        """:param weight: optional (n_classes,) tensor of per-class weights."""
        super(UnMaskedWeightedNLLLoss, self).__init__()
        self.weight = weight
        self.loss = nn.NLLLoss(weight=weight,
                               reduction='sum')

    def forward(self, pred, target):
        """
        pred -> batch*seq_len, n_classes (log-probabilities)
        target -> batch*seq_len
        """
        # Idiomatic None check (original used `type(self.weight) == type(None)`).
        if self.weight is None:
            loss = self.loss(pred, target)
        else:
            loss = self.loss(pred, target) \
                / torch.sum(self.weight[target])
        return loss
108
+
109
class GatedSelection(nn.Module):
    """Fuses two feature streams through a learned sigmoid gate."""

    def __init__(self, hidden_size):
        super().__init__()
        self.context_trans = nn.Linear(hidden_size, hidden_size)
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, x1, x2):
        """Return ReLU(fc(g * x1 + (1 - g) * x2')) where x2' = context_trans(x2)."""
        context = self.context_trans(x2)
        gate = self.sigmoid(self.linear1(x1) + self.linear2(context))
        fused = gate * x1 + (1 - gate) * context
        return self.relu(self.fc(fused))
124
+
125
def mask_logic(alpha, adj):
    """Mask attention logits with an adjacency indicator.

    Positions where ``adj`` is 0 are pushed to a very large negative value
    (~ -1e30) so a subsequent softmax assigns them near-zero probability;
    positions where ``adj`` is 1 are left untouched.

    :param alpha: attention logits
    :param adj: adjacency indicator, broadcastable to ``alpha`` (1 = keep)
    :return: masked logits
    """
    penalty = (1 - adj) * 1e30
    return alpha - penalty
133
+
134
class GatLinear(nn.Module):
    """Additive (concat + linear) attention gatherer over a node's predecessors."""

    def __init__(self, hidden_size):
        super().__init__()
        # Scores the concatenation [query; key] -> one scalar logit per context node.
        self.linear = nn.Linear(hidden_size * 2, 1)

    def forward(self, Q, K, V, adj):
        """
        Information gatherer with linear attention.

        :param Q: (B, D) query utterance
        :param K: (B, N, D) context keys
        :param V: (B, N, D) context values
        :param adj: (B, N) adjacency row of the i-th node (1 = edge present)
        :return: (attn_weight (B, 1, N), attn_sum (B, D))
        """
        ctx_len = K.size(1)
        query = Q.unsqueeze(1).expand(-1, ctx_len, -1)          # (B, N, D)
        scores = self.linear(torch.cat((query, K), dim=2))      # (B, N, 1)
        scores = scores.permute(0, 2, 1)                        # (B, 1, N)
        # Suppress non-neighbour positions before the softmax.
        scores = mask_logic(scores, adj.unsqueeze(1))           # (B, 1, N)
        attn_weight = F.softmax(scores, dim=2)                  # (B, 1, N)
        attn_sum = torch.bmm(attn_weight, V).squeeze(1)         # (B, D)
        return attn_weight, attn_sum
171
+
172
class GatDot(nn.Module):
    """Dot-product attention gatherer over a node's predecessors."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, Q, K, V, adj):
        """
        Information gatherer with dot-product attention.

        :param Q: (B, D) query utterance
        :param K: (B, N, D) context keys
        :param V: (B, N, D) context values
        :param adj: (B, N) adjacency row of the i-th node (1 = edge present)
        :return: (attn_weight (B, 1, N), attn_sum (B, D))
        """
        proj_q = self.linear1(Q).unsqueeze(2)                   # (B, D, 1)
        proj_k = self.linear2(K)                                # (B, N, D)
        scores = torch.bmm(proj_k, proj_q).permute(0, 2, 1)     # (B, 1, N)
        # Suppress non-neighbour positions before the softmax.
        scores = mask_logic(scores, adj.unsqueeze(1))           # (B, 1, N)
        attn_weight = F.softmax(scores, dim=2)                  # (B, 1, N)
        attn_sum = torch.bmm(attn_weight, V).squeeze(1)         # (B, D)
        return attn_weight, attn_sum
204
+
205
class GatLinear_rel(nn.Module):
    """Additive attention gatherer with speaker-relation embeddings in the score."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size * 3, 1)
        # Two relation types (e.g. same speaker / different speaker).
        self.rel_emb = nn.Embedding(2, hidden_size)

    def forward(self, Q, K, V, adj, s_mask):
        """
        Information gatherer with linear attention over [query; key; relation].

        :param Q: (B, D) query utterance
        :param K: (B, N, D) context keys
        :param V: (B, N, D) context values
        :param adj: (B, N) adjacency row of the i-th node (1 = edge present)
        :param s_mask: (B, N) relation ids (0/1) relative to the query speaker
        :return: (attn_weight (B, 1, N), attn_sum (B, D))
        """
        rel = self.rel_emb(s_mask)                              # (B, N, D)
        ctx_len = K.size(1)
        query = Q.unsqueeze(1).expand(-1, ctx_len, -1)          # (B, N, D)
        scores = self.linear(torch.cat((query, K, rel), dim=2)) # (B, N, 1)
        scores = scores.permute(0, 2, 1)                        # (B, 1, N)
        # Suppress non-neighbour positions before the softmax.
        scores = mask_logic(scores, adj.unsqueeze(1))           # (B, 1, N)
        attn_weight = F.softmax(scores, dim=2)                  # (B, 1, N)
        attn_sum = torch.bmm(attn_weight, V).squeeze(1)         # (B, D)
        return attn_weight, attn_sum
246
+
247
+
248
class GatDot_rel(nn.Module):
    """Dot-product attention gatherer with an additive relation-embedding bias."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        # Two relation types (e.g. same speaker / different speaker).
        self.rel_emb = nn.Embedding(2, hidden_size)

    def forward(self, Q, K, V, adj, s_mask):
        """
        Information gatherer with dot-product attention plus a relation bias.

        :param Q: (B, D) query utterance
        :param K: (B, N, D) context keys
        :param V: (B, N, D) context values
        :param adj: (B, N) adjacency row of the i-th node (1 = edge present)
        :param s_mask: (B, N) relation ids (0/1) relative to the query speaker
        :return: (attn_weight (B, 1, N), attn_sum (B, D))
        """
        rel_bias = self.linear3(self.rel_emb(s_mask))           # (B, N, 1)
        proj_q = self.linear1(Q).unsqueeze(2)                   # (B, D, 1)
        proj_k = self.linear2(K)                                # (B, N, D)
        scores = (torch.bmm(proj_k, proj_q) + rel_bias).permute(0, 2, 1)  # (B, 1, N)
        # Suppress non-neighbour positions before the softmax.
        scores = mask_logic(scores, adj.unsqueeze(1))           # (B, 1, N)
        attn_weight = F.softmax(scores, dim=2)                  # (B, 1, N)
        attn_sum = torch.bmm(attn_weight, V).squeeze(1)         # (B, D)
        return attn_weight, attn_sum
283
+
284
+
285
class GAT_dialoggcn(nn.Module):
    '''
    Relation-aware GAT gatherer with a full (D, D) transform per relation type.
    H_i = alpha_ij(W_rH_j)
    alpha_ij = attention(H_i, H_j)
    '''
    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(hidden_size * 2, 1)
        # One (D, D) transform matrix per relation type (2 relations).
        self.rel_emb = nn.Parameter(torch.randn(2, hidden_size, hidden_size))

    def forward(self, Q, K, V, adj, s_mask_onehot):
        '''
        information gatherer with linear attention
        :param Q: (B, D) # query utterance
        :param K: (B, N, D) # context
        :param V: (B, N, D) # context
        :param adj: (B, N) # the adj matrix of the i th node
        :param s_mask_onehot: (B, N, 2) # one-hot relation indicator per edge
        :return: (attn_weight (B, 1, N), attn_sum (B, D))
        '''
        B = K.size()[0]
        N = K.size()[1]
        Q = Q.unsqueeze(1).expand(-1, N, -1)  # (B, N, D)
        X = torch.cat((Q,K), dim = 2)  # (B, N, 2D)
        alpha = self.linear(X).permute(0,2,1)  # (B, 1, N)
        adj = adj.unsqueeze(1)
        # Push non-neighbour logits to -1e30 so softmax ignores them.
        alpha = mask_logic(alpha, adj)  # (B, 1, N)

        attn_weight = F.softmax(alpha, dim = 2)  # (B, 1, N)

        D = self.rel_emb.size()[2]
        # Select the per-edge relation matrix W_r by one-hot matmul:
        # flatten the two (D, D) matrices to rows, pick per edge, reshape back.
        rel_emb = self.rel_emb.unsqueeze(0).expand(B,-1,-1,-1)
        rel_emb = rel_emb.reshape((B, 2, D*D))
        Wr = torch.bmm(s_mask_onehot, rel_emb).reshape((B, N, D, D))  # (B, N, D, D)
        Wr = Wr.reshape((B*N, D, D))

        # Apply W_r to each value as a row vector: (B*N, 1, D) x (B*N, D, D).
        V = V.unsqueeze(2).reshape((B*N, 1, -1))  # (B*N, 1, D)
        V = torch.bmm(V, Wr).unsqueeze(1)  # (B*N, 1, 1, D)
        V = V.reshape((B,N,-1))  # (B, N, D)

        attn_sum = torch.bmm(attn_weight, V).squeeze(1)  # (B, D)

        return attn_weight, attn_sum
351
+
352
+
353
class GAT_dialoggcn_v1(nn.Module):
    '''
    Relation-aware GAT gatherer that uses two Linear layers instead of full
    per-edge (D, D) matrices to avoid OOM.
    H_i = alpha_ij(W_rH_j)
    alpha_ij = attention(H_i, H_j)
    '''

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(hidden_size * 2, 1)
        # Relation-specific value transforms: Wr0 (same speaker), Wr1 (other).
        self.Wr0 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.Wr1 = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, Q, K, V, adj, s_mask):
        '''
        information gatherer with linear attention
        :param Q: (B, D) # query utterance
        :param K: (B, N, D) # context
        :param V: (B, N, D) # context
        :param adj: (B, N) # the adj matrix of the i th node
        :param s_mask: (B, N) # same-speaker indicator per edge
        :return: (attn_weight (B, 1, N), attn_sum (B, D))
        '''
        ctx_len = K.size(1)
        query = Q.unsqueeze(1).expand(-1, ctx_len, -1)              # (B, N, D)
        scores = self.linear(torch.cat((query, K), dim=2))          # (B, N, 1)
        scores = scores.permute(0, 2, 1)                            # (B, 1, N)
        # Suppress non-neighbour positions before the softmax.
        scores = mask_logic(scores, adj.unsqueeze(1))               # (B, 1, N)
        attn_weight = F.softmax(scores, dim=2)                      # (B, 1, N)

        # Blend the two relation transforms per edge via the speaker mask.
        same = s_mask.unsqueeze(2).float()                          # (B, N, 1)
        rel_V = self.Wr0(V) * same + self.Wr1(V) * (1 - same)       # (B, N, D)

        attn_sum = torch.bmm(attn_weight, rel_V).squeeze(1)         # (B, D)
        return attn_weight, attn_sum
406
+
407
+
408
class GAT_dialoggcn_v2(nn.Module):
    '''
    Relation-aware GAT gatherer (linear form, avoids OOM) that additionally
    feeds a relation embedding into the attention score.
    H_i = alpha_ij(W_rH_j)
    alpha_ij = attention(H_i, H_j, rel)
    '''

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(hidden_size * 3, 1)
        # Relation-specific value transforms: Wr0 (same speaker), Wr1 (other).
        self.Wr0 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.Wr1 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.rel_emb = nn.Embedding(2, hidden_size)

    def forward(self, Q, K, V, adj, s_mask):
        '''
        information gatherer with linear attention
        :param Q: (B, D) # query utterance
        :param K: (B, N, D) # context
        :param V: (B, N, D) # context
        :param adj: (B, N) # the adj matrix of the i th node
        :param s_mask: (B, N) # same-speaker indicator / relation id per edge
        :return: (attn_weight (B, 1, N), attn_sum (B, D))
        '''
        rel = self.rel_emb(s_mask)                                  # (B, N, D)
        ctx_len = K.size(1)
        query = Q.unsqueeze(1).expand(-1, ctx_len, -1)              # (B, N, D)
        scores = self.linear(torch.cat((query, K, rel), dim=2))     # (B, N, 1)
        scores = scores.permute(0, 2, 1)                            # (B, 1, N)
        # Suppress non-neighbour positions before the softmax.
        scores = mask_logic(scores, adj.unsqueeze(1))               # (B, 1, N)
        attn_weight = F.softmax(scores, dim=2)                      # (B, 1, N)

        # Blend the two relation transforms per edge via the speaker mask.
        same = s_mask.unsqueeze(2).float()                          # (B, N, 1)
        rel_V = self.Wr0(V) * same + self.Wr1(V) * (1 - same)       # (B, N, D)

        attn_sum = torch.bmm(attn_weight, rel_V).squeeze(1)         # (B, D)
        return attn_weight, attn_sum
462
+
463
+
464
class attentive_node_features(nn.Module):
    '''
    Method to obtain attentive node features over the graph convoluted features
    '''
    def __init__(self, hidden_size):
        super().__init__()
        self.transform = nn.Linear(hidden_size, hidden_size)

    def forward(self,features, lengths, nodal_att_type):
        '''
        features : (B, N, V)
        lengths : (B, )
        nodal_att_type : 'global', 'past', or None (None = pass-through)
        '''

        # No nodal attention requested: return the input unchanged.
        if nodal_att_type==None:
            return features

        batch_size = features.size(0)
        max_seq_len = features.size(1)
        # 1 for real positions, 0 for padding, per dialogue length.
        padding_mask = [l*[1]+(max_seq_len-l)*[0] for l in lengths]
        padding_mask = torch.tensor(padding_mask).to(features)  # (B, N)
        # Lower-triangular mask restricting attention to past + current turns.
        causal_mask = torch.ones(max_seq_len, max_seq_len).to(features)  # (N, N)
        causal_mask = torch.tril(causal_mask).unsqueeze(0)  # (1, N, N)

        if nodal_att_type=='global':
            mask = padding_mask.unsqueeze(1)
        elif nodal_att_type=='past':
            mask = padding_mask.unsqueeze(1)*causal_mask
        # NOTE(review): any other non-None nodal_att_type leaves `mask` unbound
        # and raises NameError below — confirm callers only pass
        # 'global' / 'past' / None.

        # Self-attention over nodes: score = tanh(x W . features^T).
        x = self.transform(features)  # (B, N, V)
        temp = torch.bmm(x, features.permute(0,2,1))
        alpha = F.softmax(torch.tanh(temp), dim=2)  # (B, N, N)
        alpha_masked = alpha*mask  # (B, N, N)

        # Renormalize the weights after masking.
        alpha_sum = torch.sum(alpha_masked, dim=2, keepdim=True)  # (B, N, 1)
        alpha = alpha_masked / alpha_sum  # (B, N, N)
        attn_pool = torch.bmm(alpha, features)  # (B, N, V)

        return attn_pool
506
+
507
+
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch==2.0.0+cu117
2
+ transformers==4.46.3
3
+ numpy==1.24.2
4
+ pandas==2.1.4
5
+ matplotlib==3.7.1
6
+ scikit-learn==1.2.2
7
+ tqdm==4.67.1
run.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["CUDA_VISIBLE_DEVICES"] = '1'
3
+ import numpy as np, argparse, time, pickle, random
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from dataloader import IEMOCAPDataset, get_train_loader
8
+ from model import *
9
+ from sklearn.metrics import f1_score, confusion_matrix, accuracy_score, classification_report, \
10
+ precision_recall_fscore_support
11
+ from trainer import train_or_eval_model, save_badcase
12
+ from dataset import IEMOCAPDataset
13
+ from dataloader import get_IEMOCAP_loaders
14
+ from transformers import AdamW
15
+ import copy
16
+
17
+ # We use seed = 100 for reproduction of the results reported in the paper.
18
+ seed = 100
19
+
20
+ import logging
21
+
22
def get_logger(filename, verbosity=1, name=None):
    """Create a logger that writes both to *filename* (truncating it) and to the console.

    verbosity: 0 -> DEBUG, 1 -> INFO (default), 2 -> WARNING.
    """
    levels = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    fmt = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    log = logging.getLogger(name)
    log.setLevel(levels[verbosity])

    # Same format for the file sink and the console sink.
    for handler in (logging.FileHandler(filename, "w"), logging.StreamHandler()):
        handler.setFormatter(fmt)
        log.addHandler(handler)

    return log
39
+
40
def seed_everything(seed=seed):
    """Seed every relevant RNG (python, numpy, torch CPU and CUDA) for reproducibility.

    The default value binds the module-level `seed` (100) at definition time.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
+
49
+
50
if __name__ == '__main__':

    # Logs and checkpoints go under ./saved_models/<dataset_name>/ ; the
    # directory must already exist (see saved_models/README.txt).
    path = './saved_models/'

    parser = argparse.ArgumentParser()
    parser.add_argument('--bert_model_dir', type=str, default='')
    parser.add_argument('--bert_tokenizer_dir', type=str, default='')

    parser.add_argument('--bert_dim', type = int, default=1024)
    parser.add_argument('--hidden_dim', type = int, default=300)
    parser.add_argument('--mlp_layers', type=int, default=2, help='Number of output mlp layers.')
    parser.add_argument('--gnn_layers', type=int, default=2, help='Number of gnn layers.')
    parser.add_argument('--emb_dim', type=int, default=1024, help='Feature size.')

    parser.add_argument('--attn_type', type=str, default='rgcn', choices=['dotprod','linear','bilinear', 'rgcn'], help='Feature size.')
    parser.add_argument('--no_rel_attn', action='store_true', default=False, help='no relation for edges' )

    parser.add_argument('--max_sent_len', type=int, default=200,
                        help='max content length for each text, if set to 0, then the max length has no constrain')

    parser.add_argument('--no_cuda', action='store_true', default=False, help='does not use GPU')

    parser.add_argument('--dataset_name', default='IEMOCAP', type= str, help='dataset name, IEMOCAP or MELD or DailyDialog')

    # Context-window sizes used when building graph edges.
    parser.add_argument('--windowps', type=int, default=1,
                        help='context window size for constructing edges in graph model for past utterances for short')
    parser.add_argument('--windowpl', type=int, default=5,
                        help='context window size for constructing edges in graph model for past utterances for long')

    parser.add_argument('--windowf', type=int, default=0,
                        help='context window size for constructing edges in graph model for future utterances')

    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')

    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', help='learning rate')

    parser.add_argument('--dropout', type=float, default=0, metavar='dropout', help='dropout rate')

    parser.add_argument('--batch_size', type=int, default=16, metavar='BS', help='batch size')

    parser.add_argument('--epochs', type=int, default=20, metavar='E', help='number of epochs')

    parser.add_argument('--tensorboard', action='store_true', default=False, help='Enables tensorboard log')

    parser.add_argument('--nodal_att_type', type=str, default=None, choices=['global','past'], help='type of nodal attention')

    # Curriculum-learning ("baby steps") controls used by the epoch loop below.
    parser.add_argument('--curriculum', action='store_true', default=False, help='Enables curriculum learning')

    parser.add_argument('--bucket_number', type=int, default=0)

    parser.add_argument('--max_epoch_per_baby_step', type=int, default=0)

    parser.add_argument('--diffloss', type=float , default=0.1, help='diffloss beta')

    args = parser.parse_args()
    print(args)

    seed_everything()

    args.cuda = torch.cuda.is_available() and not args.no_cuda

    if args.cuda:
        print('Running on GPU')
    else:
        print('Running on CPU')

    if args.tensorboard:
        from tensorboardX import SummaryWriter

        writer = SummaryWriter()

    logger = get_logger(path + args.dataset_name + '/logging.log')
    logger.info('start training on GPU {}!'.format(os.environ["CUDA_VISIBLE_DEVICES"]))
    logger.info(args)

    cuda = args.cuda
    n_epochs = args.epochs
    batch_size = args.batch_size
    # Only the valid/test loaders are built once here; the train loader is
    # rebuilt every epoch for curriculum learning.
    valid_loader, test_loader, speaker_vocab, label_vocab, person_vec = get_IEMOCAP_loaders(dataset_name=args.dataset_name, batch_size=batch_size, num_workers=0, args = args)
    n_classes = len(label_vocab['itos'])

    print('building model..')
    model = DAGERC_new_4(args, n_classes)
    # Human-readable class names used only for the per-class F1 report.
    if args.dataset_name == 'IEMOCAP':
        class_labels = ['excitement', 'neutral', 'frustration', 'sadness', 'happiness', 'anger']
    else:
        class_labels = ['Neutral', 'Surprise', 'Fear', 'Sadness', 'Joy', 'Disgust', 'Anger']

    if torch.cuda.device_count() > 1:
        print('Multi-GPU...........')
        model = nn.DataParallel(model, device_ids = range(torch.cuda.device_count()))
    if cuda:
        model.cuda()

    # ignore_index=-1 skips padded utterance positions.
    loss_function = nn.CrossEntropyLoss(ignore_index=-1)
    optimizer = AdamW(model.parameters() , lr=args.lr)

    best_fscore, best_acc, best_loss, best_label, best_pred, best_mask = None, None, None, None, None, None
    all_fscore, all_acc, all_loss = [], [], []
    best_acc = 0.
    best_fscore = 0.
    best_epoch = 0
    best_model = None
    for e in range(n_epochs):
        start_time = time.time()
        # Curriculum learning: grow the training subset one difficulty bucket
        # per epoch until all buckets are included.
        if e + 1 < args.bucket_number:
            train_loader = get_train_loader(dataset_name=args.dataset_name, batch_size=batch_size, num_workers=0,
                                            args=args, babystep_index=e + 1)
        else:
            train_loader = get_train_loader(dataset_name=args.dataset_name, batch_size=batch_size, num_workers=0,
                                            args=args, babystep_index=args.bucket_number)
        if args.dataset_name == 'DailyDialog':
            # DailyDialog: trainer returns 6 values (micro/macro F1).
            train_loss, train_acc, _, _, train_micro_fscore, train_macro_fscore = train_or_eval_model(model,
                                                                                                      loss_function,
                                                                                                      train_loader, e,
                                                                                                      cuda,
                                                                                                      args, optimizer,
                                                                                                      True)
            valid_loss, valid_acc, _, _, valid_micro_fscore, valid_macro_fscore = train_or_eval_model(model,
                                                                                                      loss_function,
                                                                                                      valid_loader, e,
                                                                                                      cuda, args)
            test_loss, test_acc, test_label, test_pred, test_micro_fscore, test_macro_fscore = train_or_eval_model(
                model, loss_function, test_loader, e, cuda, args)

            all_fscore.append([valid_micro_fscore, test_micro_fscore, valid_macro_fscore, test_macro_fscore])

            logger.info( 'Epoch: {}, train_loss: {}, train_acc: {}, train_micro_fscore: {}, train_macro_fscore: {}, valid_loss: {}, valid_acc: {}, valid_micro_fscore: {}, valid_macro_fscore: {}, test_loss: {}, test_acc: {}, test_micro_fscore: {}, test_macro_fscore: {}, time: {} sec'. \
                format(e + 1, train_loss, train_acc, train_micro_fscore, train_macro_fscore, valid_loss, valid_acc, valid_micro_fscore, valid_macro_fscore, test_loss, test_acc,
                       test_micro_fscore, test_macro_fscore, round(time.time() - start_time, 2)))

        else:
            # IEMOCAP/MELD/EmoryNLP: trainer returns 7 values (weighted F1,
            # per-class F1, macro F1).
            train_loss, train_acc, _, _, train_fscore, _ , _ = train_or_eval_model(model, loss_function,
                                                                                   train_loader, e, cuda,
                                                                                   args, optimizer, True)
            valid_loss, valid_acc, _, _, valid_fscore, _ , _ = train_or_eval_model(model, loss_function,
                                                                                   valid_loader, e, cuda, args)
            test_loss, test_acc, test_label, test_pred, test_fscore, test_f1_per_class, avg_macro_fscore = train_or_eval_model(model, loss_function, test_loader, e, cuda, args)

            all_fscore.append([valid_fscore, test_fscore])

            logger.info(
                'Epoch: {}, train_loss: {}, train_acc: {}, train_fscore: {}, valid_loss: {}, valid_acc: {}, valid_fscore: {}, test_loss: {}, test_acc: {}, test_fscore: {}, avg_macro_fscore: {}, time: {} sec'. \
                    format(e + 1, train_loss, train_acc, train_fscore, valid_loss, valid_acc, valid_fscore, test_loss,
                           test_acc,
                           test_fscore, avg_macro_fscore, round(time.time() - start_time, 2)))

            f1_with_labels = {label: f1 for label, f1 in zip(class_labels, test_f1_per_class)}

            logger.info(f"Test F1 per class: {f1_with_labels}")

            # Model selection on TEST weighted F1.
            # NOTE(review): this update only runs in the non-DailyDialog
            # branch, so for DailyDialog best_model stays None — confirm.
            if (test_fscore > best_fscore):
                best_fscore = test_fscore
                best_model = copy.deepcopy(model.state_dict())
                # print(test_fscore)
                # print(best_model)
                best_epoch = e + 1
                # torch.save(model.state_dict(), path + args.dataset_name + '/model_' + str(e) + '_' + str(test_acc)+ '.pkl')

        e += 1  # no effect: `e` is reassigned by the for loop each iteration
    # Save the best checkpoint found during training.
    torch.save(best_model, path + args.dataset_name + '/model_' + str(best_epoch) + '_' + str(best_fscore) + '_' + str(
        args.gnn_layers) + '.pkl')
    # print(best_model)
    if args.tensorboard:
        writer.close()

    logger.info('finish training!')

    #print('Test performance..')
    # Sort epoch records by (validation score, test score), best first.
    all_fscore = sorted(all_fscore, key=lambda x: (x[0], x[1]), reverse=True)
    #print('Best F-Score based on validation:', all_fscore[0][1])
    #print('Best F-Score based on test:', max([f[1] for f in all_fscore]))

    #logger.info('Test performance..')
    #logger.info('Best F-Score based on validation:{}'.format(all_fscore[0][1]))
    #logger.info('Best F-Score based on test:{}'.format(max([f[1] for f in all_fscore])))

    if args.dataset_name == 'DailyDialog':
        logger.info('Best micro/macro F-Score based on validation:{}/{}'.format(all_fscore[0][1], all_fscore[0][3]))
        all_fscore = sorted(all_fscore, key=lambda x: x[1], reverse=True)
        logger.info('Best micro/macro F-Score based on test:{}/{}'.format(all_fscore[0][1], all_fscore[0][3]))
    else:
        logger.info('Best F-Score based on validation:{}'.format(all_fscore[0][1]))
        logger.info('Best F-Score based on test:{}'.format(max([f[1] for f in all_fscore])))

    #save_badcase(best_model, test_loader, cuda, args, speaker_vocab, label_vocab)
239
+
saved_models/IEMOCAP/README.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ 请在此处存放训练后IEMOCAP的模型。
2
+ Please store the trained IEMOCAP model here.
saved_models/MELD/README.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ 请在此处存放训练后MELD的模型。
2
+ Please store the trained MELD model here.
saved_models/README.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ 请在此处的IEMOCAP和MELD文件夹中存放训练后的模型。
2
+ Please store the trained models in the IEMOCAP and MELD folders here.
similarity_matrix.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
# 2-D (valence, arousal) coordinate for each emotion label, mostly placed on
# the unit circle at odd multiples of pi/20 — presumably a Geneva Emotion
# Wheel style layout; TODO confirm the source. x > 0 is positive valence,
# y > 0 is high arousal. "boredom" sits inside the circle and "neutral" at
# the origin. Synonymous labels (happiness/joy, pride/elation, ...) share a
# position.
emotion_positions = {
    "pleasure": (np.cos(np.pi / 20), np.sin(np.pi / 20)),
    "happiness": (np.cos(3 * np.pi / 20), np.sin(3 * np.pi / 20)),
    "joy": (np.cos(3 * np.pi / 20), np.sin(3 * np.pi / 20)),
    "pride": (np.cos(5 * np.pi / 20), np.sin(5 * np.pi / 20)),
    "elation": (np.cos(5 * np.pi / 20), np.sin(5 * np.pi / 20)),
    "excitement": (np.cos(7 * np.pi / 20), np.sin(7 * np.pi / 20)),
    "surprise": (np.cos(9 * np.pi / 20), np.sin(9 * np.pi / 20)),
    "interest": (np.cos(9 * np.pi / 20), np.sin(9 * np.pi / 20)),
    "anger": (-np.cos(9 * np.pi / 20), np.sin(9 * np.pi / 20)),
    "irritation": (-np.cos(9 * np.pi / 20), np.sin(9 * np.pi / 20)),
    "hate": (-np.cos(7 * np.pi / 20), np.sin(7 * np.pi / 20)),
    "contempt": (-np.cos(5 * np.pi / 20), np.sin(5 * np.pi / 20)),
    "disgust": (-np.cos(3 * np.pi / 20), np.sin(3 * np.pi / 20)),
    "fear": (-np.cos(np.pi / 20), np.sin(np.pi / 20)),
    "boredom": (-0.5, 0),
    "disappointment": (-np.cos(np.pi / 20), -np.sin(np.pi / 20)),
    "frustration": (-np.cos(np.pi / 20), -np.sin(np.pi / 20)),
    "shame": (-np.cos(3 * np.pi / 20), -np.sin(3 * np.pi / 20)),
    "regret": (-np.cos(5 * np.pi / 20), -np.sin(5 * np.pi / 20)),
    "guilt": (-np.cos(7 * np.pi / 20), -np.sin(7 * np.pi / 20)),
    "sadness": (-np.cos(9 * np.pi / 20), -np.sin(9 * np.pi / 20)),
    "compassion": (np.cos(9 * np.pi / 20), -np.sin(9 * np.pi / 20)),
    "relief": (np.cos(7 * np.pi / 20), -np.sin(7 * np.pi / 20)),
    "admiration": (np.cos(5 * np.pi / 20), -np.sin(5 * np.pi / 20)),
    "love": (np.cos(3 * np.pi / 20), -np.sin(3 * np.pi / 20)),
    "contentment": (np.cos(np.pi / 20), -np.sin(np.pi / 20)),
    "neutral": (0, 0)
}
33
+
34
+ # 计算两个情感标签之间的余弦相似度 Compute the cosine similarity between two sentiment labels.
35
def cosine_similarity(p1, p2):
    """Cosine similarity of two coordinate vectors; 0.0 if either is the zero vector."""
    norm_a = np.linalg.norm(p1)
    norm_b = np.linalg.norm(p2)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return np.dot(p1, p2) / (norm_a * norm_b)


def compute_similarity_matrix(emotion_positions, n_dataset):
    """Pairwise similarity between emotion labels from their 2-D positions.

    Rules, with valence = x coordinate and N = n_dataset (the number of
    emotion classes in the target dataset):
      * opposite valence polarity       -> 0
      * either valence exactly zero     -> 1 / N
      * same polarity                   -> max(cosine similarity, 0)
      * "neutral" vs any other label    -> 1 / N (both directions)
      * the diagonal is left at 0
    Returns (similarity matrix, label -> index mapping).
    """
    labels = list(emotion_positions.keys())
    size = len(labels)
    uniform = 1 / n_dataset  # similarity assigned to the "no preference" cases
    matrix = np.zeros((size, size))
    label_index = {name: k for k, name in enumerate(labels)}
    neutral_idx = labels.index("neutral")

    for a in range(size):
        for b in range(size):
            if a == b:
                continue
            pos_a = emotion_positions[labels[a]]
            pos_b = emotion_positions[labels[b]]
            valence_product = pos_a[0] * pos_b[0]
            if valence_product < 0:
                matrix[a][b] = 0
            elif valence_product == 0:
                matrix[a][b] = uniform
            else:
                matrix[a][b] = max(cosine_similarity(np.array(pos_a), np.array(pos_b)), 0)

    # "neutral" is equally (dis)similar to every other label.
    for a in range(size):
        if a != neutral_idx:
            matrix[neutral_idx][a] = uniform
            matrix[a][neutral_idx] = uniform

    return matrix, label_index
85
+
86
+
87
def get_similarity_matrix(dataset):
    """Return (similarity_matrix, emotion_to_index) for the given dataset name.

    IEMOCAP has 6 emotion classes; every other supported dataset uses 7.
    Matrix values near 1 mean very similar labels, near 0 very different ones.
    """
    label_count = 6 if dataset == 'IEMOCAP' else 7
    return compute_similarity_matrix(emotion_positions, label_count)
trainer.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np, argparse, time, pickle, random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.data import DataLoader
6
+ from torch.utils.data.sampler import SubsetRandomSampler
7
+ from dataloader import IEMOCAPDataset
8
+ from sklearn.metrics import f1_score, confusion_matrix, accuracy_score, classification_report, \
9
+ precision_recall_fscore_support
10
+ from utils import person_embed
11
+ from tqdm import tqdm
12
+ import json
13
+
14
+
15
def train_or_eval_model(model, loss_function, dataloader, epoch, cuda, args, optimizer=None, train=False):
    '''
    Run one full pass over `dataloader`: training when train=True (an
    optimizer is then required), evaluation otherwise.

    Returns for args.dataset_name in {IEMOCAP, MELD, EmoryNLP}:
        (avg_loss, avg_accuracy, labels, preds, weighted_f1, f1_per_class, macro_f1)
    and otherwise (e.g. DailyDialog):
        (avg_loss, avg_accuracy, labels, preds, micro_f1, macro_f1)
    Accuracy/F1 values are percentages rounded to two decimals; padded
    positions (label == -1) are excluded from all metrics.
    '''
    losses, preds, labels = [], [], []
    scores, vids = [], []  # NOTE(review): never populated — presumably leftovers

    # Training without an optimizer is a programming error.
    assert not train or optimizer != None
    if train:
        model.train()
        # dataloader = tqdm(dataloader)
    else:
        model.eval()

    cnt = 0  # NOTE(review): unused
    for data in dataloader:
        if train:
            optimizer.zero_grad()
        # text_ids, text_feature, speaker_ids, labels, umask = [d.cuda() for d in data] if cuda else data
        features, label, adj_1, adj_2, s_mask, s_mask_onehot, lengths, speakers, utterances = data
        # speaker_vec = person_embed(speaker_ids, person_vec)
        if cuda:
            features = features.cuda()
            label = label.cuda()
            adj_1 = adj_1.cuda()
            adj_2 = adj_2.cuda()
            s_mask = s_mask.cuda()
            s_mask_onehot = s_mask_onehot.cuda()
            lengths = lengths.cuda()

        # The model yields per-utterance log-probabilities plus an auxiliary
        # difference loss that is added to the cross-entropy objective.
        log_prob, diff_loss = model(features, adj_1, adj_2, s_mask, s_mask_onehot, lengths)  # (B, N, C)
        # CrossEntropyLoss expects (B, C, N): move the class axis forward.
        loss = loss_function(log_prob.permute(0, 2, 1), label) + diff_loss
        '''
        # print(speakers)
        log_prob = model(features, adj_1, adj_2, s_mask, s_mask_onehot, lengths) # (B, N, C)
        # print(label)
        loss = loss_function(log_prob.permute(0,2,1), label)
        '''
        label = label.cpu().numpy().tolist()
        pred = torch.argmax(log_prob, dim = 2).cpu().numpy().tolist()
        preds += pred
        labels += label
        losses.append(loss.item())

        if train:
            loss_val = loss.item()  # NOTE(review): unused
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            if args.tensorboard:
                # NOTE(review): `writer` is not defined in this module; with
                # --tensorboard this raises NameError (the SummaryWriter is
                # created in run.py) — confirm.
                for param in model.named_parameters():
                    writer.add_histogram(param[0], param[1].grad, epoch)
            optimizer.step()

    # Flatten predictions/labels, dropping padded positions (label == -1).
    if preds != []:
        new_preds = []
        new_labels = []
        for i, label in enumerate(labels):
            for j, l in enumerate(label):
                if l != -1:
                    new_labels.append(l)
                    new_preds.append(preds[i][j])
    else:
        # NOTE(review): returns 10 values while the success paths return 7/6;
        # callers unpacking a fixed tuple would fail on an empty dataloader —
        # confirm this path is unreachable in practice.
        return float('nan'), float('nan'), [], [], float('nan'), [], [], [], [], []

    # print(preds.tolist())
    # print(labels.tolist())
    avg_loss = round(np.sum(losses) / len(losses), 4)
    avg_accuracy = round(accuracy_score(new_labels, new_preds) * 100, 2)
    if args.dataset_name in ['IEMOCAP', 'MELD', 'EmoryNLP']:
        avg_fscore = round(f1_score(new_labels, new_preds, average='weighted') * 100, 2)
        f1_per_class = f1_score(new_labels, new_preds, average=None)  # one F1 per class
        avg_macro_fscore = round(f1_score(new_labels, new_preds, average='macro') * 100, 2)
        return avg_loss, avg_accuracy, labels, preds, avg_fscore, f1_per_class, avg_macro_fscore
    else:
        # Micro F1 restricted to labels 1..6 — excludes class 0 (presumably
        # the dominant neutral class in DailyDialog; confirm label mapping).
        avg_micro_fscore = round(f1_score(new_labels, new_preds, average='micro', labels=list(range(1, 7))) * 100, 2)
        avg_macro_fscore = round(f1_score(new_labels, new_preds, average='macro') * 100, 2)
        return avg_loss, avg_accuracy, labels, preds, avg_micro_fscore, avg_macro_fscore
94
+
95
+
96
def save_badcase(model, dataloader, cuda, args, speaker_vocab, label_vocab):
    '''
    Run `model` over `dataloader` and dump every dialogue with per-utterance
    text / speaker / gold label / prediction to badcase/<dataset_name>.json,
    then print the test F1 scores.
    '''
    preds, labels = [], []
    scores, vids = [], []  # NOTE(review): never populated
    dialogs = []
    speakers = []

    model.eval()

    for data in dataloader:

        # text_ids, text_feature, speaker_ids, labels, umask = [d.cuda() for d in data] if cuda else data
        # NOTE(review): unpacks 8 fields with a single adjacency matrix while
        # train_or_eval_model unpacks 9 (adj_1, adj_2) — confirm which
        # dataloader variant this function targets.
        features, label, adj, s_mask, s_mask_onehot, lengths, speaker, utterances = data
        # speaker_vec = person_embed(speaker_ids, person_vec)
        if cuda:
            features = features.cuda()
            label = label.cuda()
            adj = adj.cuda()
            s_mask_onehot = s_mask_onehot.cuda()
            s_mask = s_mask.cuda()
            lengths = lengths.cuda()

        # print(speakers)
        log_prob = model(features, adj, s_mask, s_mask_onehot, lengths)  # (B, N, C)

        label = label.cpu().numpy().tolist()  # (B, N)
        pred = torch.argmax(log_prob, dim = 2).cpu().numpy().tolist()  # (B, N)
        preds += pred
        labels += label
        dialogs += utterances
        speakers += speaker

    # Flatten predictions/labels, dropping padded positions (label == -1).
    if preds != []:
        new_preds = []
        new_labels = []
        for i, label in enumerate(labels):
            for j, l in enumerate(label):
                if l != -1:
                    new_labels.append(l)
                    new_preds.append(preds[i][j])
    else:
        return

    # Build one record per utterance, grouped by dialogue.
    cases = []
    for i, d in enumerate(dialogs):
        case = []
        for j, u in enumerate(d):
            case.append({
                'text': u,
                'speaker': speaker_vocab['itos'][speakers[i][j]],
                'label': label_vocab['itos'][labels[i][j]] if labels[i][j] != -1 else 'none',
                'pred': label_vocab['itos'][preds[i][j]]
            })
        cases.append(case)

    # NOTE(review): assumes the badcase/ directory already exists.
    with open('badcase/%s.json' % (args.dataset_name), 'w', encoding='utf-8') as f:
        json.dump(cases, f)

    # print(preds.tolist())
    # print(labels.tolist())
    avg_accuracy = round(accuracy_score(new_labels, new_preds) * 100, 2)
    if args.dataset_name in ['IEMOCAP', 'MELD', 'EmoryNLP']:
        avg_fscore = round(f1_score(new_labels, new_preds, average='weighted') * 100, 2)
        print('badcase saved')
        print('test_f1', avg_fscore)
        return
    else:
        # Micro F1 restricted to labels 1..6 (mirrors train_or_eval_model).
        avg_micro_fscore = round(f1_score(new_labels, new_preds, average='micro', labels=list(range(1, 7))) * 100, 2)
        avg_macro_fscore = round(f1_score(new_labels, new_preds, average='macro') * 100, 2)
        print('badcase saved')
        print('test_micro_f1', avg_micro_fscore)
        print('test_macro_f1', avg_macro_fscore)
        return
utils.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
def person_embed(speaker_ids, person_vec):
    '''
    Look up a speaker embedding for every position, substituting a zero
    vector for padding positions (speaker id == -1).

    :param speaker_ids: torch.Tensor (T, B) of speaker indices; -1 marks padding
    :param person_vec: numpy array (num_speakers, D) of speaker embeddings
    :return:
        speaker_vec: torch.FloatTensor (T, B, D)
    '''
    # Generalized: infer the embedding width from person_vec instead of
    # hard-coding 100, so any embedding size works unchanged.
    vec_dim = len(person_vec[0])
    zero_vec = [0.0] * vec_dim
    speaker_vec = []
    for t in speaker_ids:
        speaker_vec.append([person_vec[int(i)].tolist() if i != -1 else zero_vec for i in t])
    speaker_vec = torch.FloatTensor(speaker_vec)
    return speaker_vec
+ return speaker_vec