lorenzozan committed
Commit 53a92af · verified · 1 parent: b1ab244

Update modeling_me2bert.py

Files changed (1)
  1. modeling_me2bert.py (+229, -243)
modeling_me2bert.py CHANGED
@@ -1,243 +1,229 @@
-from transformers import PretrainedConfig
-from transformers import PreTrainedModel
-from transformers import AutoModel
-import torch
-from torch.autograd import Function
-
-
-class ReverseLayerF(Function):
-
-    @staticmethod
-    def forward(ctx, x, alpha):
-        ctx.alpha = alpha
-
-        return x.view_as(x)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        output = grad_output.neg() * ctx.alpha
-
-        return output, None
-
-
-class FFClassifier(torch.nn.Module):
-
-    def __init__(self, input_dim, hidden_dim, n_classes, dropout=0.0):
-        super(FFClassifier, self).__init__()
-
-        self.model = torch.nn.Sequential(
-            torch.nn.Linear(input_dim, hidden_dim),
-            torch.nn.BatchNorm1d(hidden_dim), torch.nn.ReLU(True),
-            torch.nn.Dropout(dropout), torch.nn.Linear(hidden_dim, n_classes))
-
-    def forward(self, input):
-        return self.model(input)
-
-
-class Encoder(torch.nn.Module):
-
-    def __init__(self, input_dim, hidden_dim, latent_dim):
-        super(Encoder, self).__init__()
-        self.fc1 = torch.nn.Linear(input_dim, hidden_dim, bias=True)
-        self.fc2 = torch.nn.Linear(hidden_dim, latent_dim, bias=True)
-        self.prelu = torch.nn.PReLU()
-
-    def forward(self, x):
-        x = self.prelu(self.fc1(x))
-        x = self.fc2(x)
-        return x
-
-
-class Decoder(torch.nn.Module):
-    def __init__(self, latent_dim, hidden_dim, output_dim):
-        super(Decoder, self).__init__()
-        self.fc1 = torch.nn.Linear(latent_dim, hidden_dim, bias=True)
-        self.fc2 = torch.nn.Linear(hidden_dim, output_dim, bias=True)
-        self.prelu = torch.nn.PReLU()
-
-    def forward(self, x):
-        x = self.prelu(self.fc1(x))
-        return self.fc2(x)
-
-
-class AutoEncoder(torch.nn.Module):
-    def __init__(self, input_dim, hidden_dim, latent_dim):
-        super(AutoEncoder, self).__init__()
-        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
-        self.layer_norm = torch.nn.LayerNorm(latent_dim)
-        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
-
-    def forward(self, x):
-        encoded = self.encoder(x)
-        encoded = self.layer_norm(encoded)
-        decoded = self.decoder(encoded)
-        decoded = decoded
-        return encoded, decoded
-
-
-class ME2BERTConfig(PretrainedConfig):
-    model_type = "me2bert"
-    architectures = ["ME2BERT"]
-
-    def __init__(
-        self,
-        has_gate: bool = True, has_trans=True, **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.has_gate = has_gate
-        self.has_trans = has_trans
-        self.pretrained_model_name = 'bert-base-uncased'
-
-
-class GatedCombination(torch.nn.Module):
-    def __init__(self, embedding_dim):
-        super(GatedCombination, self).__init__()
-        self.embedding_dim = embedding_dim
-
-        self.forget_gate = torch.nn.Linear(embedding_dim, embedding_dim)
-        self.input_gate = torch.nn.Linear(embedding_dim, embedding_dim)
-        self.output_gate = torch.nn.Linear(embedding_dim, embedding_dim)
-
-        self.sigmoid = torch.nn.Sigmoid()
-        self.tanh = torch.nn.Tanh()
-
-    def forward(self, frozen_output, finetuned_output):
-        forget_gate = self.sigmoid(self.forget_gate(frozen_output))
-        input_gate = self.sigmoid(self.input_gate(finetuned_output))
-
-        combined = forget_gate * frozen_output + input_gate * finetuned_output
-
-        output_gate = self.sigmoid(self.output_gate(combined))
-
-        gated_output = output_gate * self.tanh(combined)
-
-        return gated_output
-
-
-class ME2BERT(PreTrainedModel):
-    config_class = ME2BERTConfig
-
-    def __init__(
-        self,
-        config: ME2BERTConfig = None):
-        if config is None:
-            config = ME2BERTConfig()
-
-        super().__init__(config)
-        self.n_mf_classes = 5
-        self.n_domain_classes = 2
-        pretrained_model_name = config.pretrained_model_name
-        self.has_gate = config.has_gate
-        self.has_trans = config.has_trans
-        self.emotion_labels = [0, 0, 0, 0, 0]
-        self.feature = AutoModel.from_pretrained(pretrained_model_name)
-        self.bert_frozen = AutoModel.from_pretrained(pretrained_model_name)
-
-        for param in self.bert_frozen.parameters():
-            param.requires_grad = False
-
-        self.embedding_dim = self.feature.config.hidden_size
-        latent_dim = 128
-        self.emotion_dim = 5
-
-        self.gated_combination = (
-            GatedCombination(embedding_dim=self.embedding_dim)
-        )
-
-        self.trans_module = (
-            AutoEncoder(self.embedding_dim, 256, latent_dim))
-
-        initial_dim = self.embedding_dim + self.n_domain_classes + self.emotion_dim
-
-        self.mf_classifier = FFClassifier(
-            initial_dim, latent_dim, self.n_mf_classes, .0
-        )
-
-        self.domain_classifier = FFClassifier(
-            self.embedding_dim, latent_dim, self.n_domain_classes,
-
-        )
-
-    def gen_feature_embeddings(self, input_ids, attention_mask):
-        feature = self.feature(input_ids=input_ids, attention_mask=attention_mask)
-        return feature.last_hidden_state, feature.pooler_output
-
-    def forward(self,
-                input_ids,
-                attention_mask, return_dict=False):
-
-        _, pooler_output = self.gen_feature_embeddings(
-            input_ids, attention_mask)
-
-        with torch.no_grad():
-            frozen_output = self.bert_frozen(input_ids=input_ids, attention_mask=attention_mask)
-
-        frozen_output = frozen_output.pooler_output
-
-        device = pooler_output.device
-        rec_embeddings = None
-        if self.has_trans:
-            rec_embeddings = pooler_output
-            _, pooler_output = self.trans_module(rec_embeddings)
-            if self.has_gate:
-                gated_output = self.gated_combination(frozen_output, pooler_output)
-            else:
-                gated_output = pooler_output
-        else:
-            gated_output = pooler_output
-
-        domain_labels = torch.zeros(gated_output.shape[0]).long().to(device)
-        domain_feature = torch.nn.functional.one_hot(
-            domain_labels, num_classes=self.n_domain_classes).squeeze(1)
-
-        emotion_features = None
-        if self.emotion_labels is not None:
-            if isinstance(self.emotion_labels, list):
-                emotion_tensor = torch.tensor(self.emotion_labels, dtype=torch.float32)
-                emotion_features = emotion_tensor.repeat(gated_output.shape[0], 1)
-            else:
-                emotion_features = torch.nn.functional.one_hot(
-                    self.emotion_labels.long(), num_classes=self.emotion_dim
-                ).squeeze(1)
-
-        if emotion_features is not None:
-            emotion_features = emotion_features[:gated_output.shape[0], :]
-            class_output = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
-
-        else:
-            emotion_features = torch.zeros(gated_output.shape[0], self.emotion_dim).to(device)
-            class_output = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
-
-        class_output = torch.sigmoid(self.mf_classifier(class_output))
-        if return_dict:
-            mft_dimensions = [
-                'CH',
-                'FC',
-                'LB',
-                'AS',
-                'PD'
-            ]
-
-            result_list = []
-            for i in range(class_output.shape[0]):
-                row_scores = [round(score.item(), 5) for score in class_output[i]]
-                row_dict = dict(zip(mft_dimensions, row_scores))
-                result_list.append(row_dict)
-            return result_list
-
-        return class_output
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, **kwargs):
-        if config is None:
-            try:
-                config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
-            except (OSError, ValueError):
-                config = cls.config_class()
-
-        return super().from_pretrained(
-            pretrained_model_name_or_path,
-            *model_args,
-            config=config,
-            **kwargs,
-        )
+from transformers import PretrainedConfig
+from transformers import PreTrainedModel
+from transformers import AutoModel
+import torch
+from torch.autograd import Function
+
+
+class ReverseLayerF(Function):
+
+    @staticmethod
+    def forward(ctx, x, alpha):
+        ctx.alpha = alpha
+
+        return x.view_as(x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        output = grad_output.neg() * ctx.alpha
+
+        return output, None
+
+
+class FFClassifier(torch.nn.Module):
+
+    def __init__(self, input_dim, hidden_dim, n_classes, dropout=0.0):
+        super(FFClassifier, self).__init__()
+
+        self.model = torch.nn.Sequential(
+            torch.nn.Linear(input_dim, hidden_dim),
+            torch.nn.BatchNorm1d(hidden_dim), torch.nn.ReLU(True),
+            torch.nn.Dropout(dropout), torch.nn.Linear(hidden_dim, n_classes))
+
+    def forward(self, input):
+        return self.model(input)
+
+
+class Encoder(torch.nn.Module):
+
+    def __init__(self, input_dim, hidden_dim, latent_dim):
+        super(Encoder, self).__init__()
+        self.fc1 = torch.nn.Linear(input_dim, hidden_dim, bias=True)
+        self.fc2 = torch.nn.Linear(hidden_dim, latent_dim, bias=True)
+        self.prelu = torch.nn.PReLU()
+
+    def forward(self, x):
+        x = self.prelu(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+class Decoder(torch.nn.Module):
+    def __init__(self, latent_dim, hidden_dim, output_dim):
+        super(Decoder, self).__init__()
+        self.fc1 = torch.nn.Linear(latent_dim, hidden_dim, bias=True)
+        self.fc2 = torch.nn.Linear(hidden_dim, output_dim, bias=True)
+        self.prelu = torch.nn.PReLU()
+
+    def forward(self, x):
+        x = self.prelu(self.fc1(x))
+        return self.fc2(x)
+
+
+class AutoEncoder(torch.nn.Module):
+    def __init__(self, input_dim, hidden_dim, latent_dim):
+        super(AutoEncoder, self).__init__()
+        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
+        self.layer_norm = torch.nn.LayerNorm(latent_dim)
+        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
+
+    def forward(self, x):
+        encoded = self.encoder(x)
+        encoded = self.layer_norm(encoded)
+        decoded = self.decoder(encoded)
+        decoded = decoded
+        return encoded, decoded
+
+
+class GatedCombination(torch.nn.Module):
+    def __init__(self, embedding_dim):
+        super(GatedCombination, self).__init__()
+        self.embedding_dim = embedding_dim
+
+        self.forget_gate = torch.nn.Linear(embedding_dim, embedding_dim)
+        self.input_gate = torch.nn.Linear(embedding_dim, embedding_dim)
+        self.output_gate = torch.nn.Linear(embedding_dim, embedding_dim)
+
+        self.sigmoid = torch.nn.Sigmoid()
+        self.tanh = torch.nn.Tanh()
+
+    def forward(self, frozen_output, finetuned_output):
+        forget_gate = self.sigmoid(self.forget_gate(frozen_output))
+        input_gate = self.sigmoid(self.input_gate(finetuned_output))
+
+        combined = forget_gate * frozen_output + input_gate * finetuned_output
+
+        output_gate = self.sigmoid(self.output_gate(combined))
+
+        gated_output = output_gate * self.tanh(combined)
+
+        return gated_output
+
+
+class ME2BERT(PreTrainedModel):
+    config_class = ME2BERTConfig
+
+    def __init__(
+        self,
+        config: ME2BERTConfig = None):
+        if config is None:
+            config = ME2BERTConfig()
+
+        super().__init__(config)
+        self.n_mf_classes = 5
+        self.n_domain_classes = 2
+        pretrained_model_name = config.pretrained_model_name
+        self.has_gate = config.has_gate
+        self.has_trans = config.has_trans
+        self.emotion_labels = [0, 0, 0, 0, 0]
+        self.feature = AutoModel.from_pretrained(pretrained_model_name)
+        self.bert_frozen = AutoModel.from_pretrained(pretrained_model_name)
+
+        for param in self.bert_frozen.parameters():
+            param.requires_grad = False
+
+        self.embedding_dim = self.feature.config.hidden_size
+        latent_dim = 128
+        self.emotion_dim = 5
+
+        self.gated_combination = (
+            GatedCombination(embedding_dim=self.embedding_dim)
+        )
+
+        self.trans_module = (
+            AutoEncoder(self.embedding_dim, 256, latent_dim))
+
+        initial_dim = self.embedding_dim + self.n_domain_classes + self.emotion_dim
+
+        self.mf_classifier = FFClassifier(
+            initial_dim, latent_dim, self.n_mf_classes, .0
+        )
+
+        self.domain_classifier = FFClassifier(
+            self.embedding_dim, latent_dim, self.n_domain_classes,
+
+        )
+
+    def gen_feature_embeddings(self, input_ids, attention_mask):
+        feature = self.feature(input_ids=input_ids, attention_mask=attention_mask)
+        return feature.last_hidden_state, feature.pooler_output
+
+    def forward(self,
+                input_ids,
+                attention_mask, return_dict=False):
+
+        _, pooler_output = self.gen_feature_embeddings(
+            input_ids, attention_mask)
+
+        with torch.no_grad():
+            frozen_output = self.bert_frozen(input_ids=input_ids, attention_mask=attention_mask)
+
+        frozen_output = frozen_output.pooler_output
+
+        device = pooler_output.device
+        rec_embeddings = None
+        if self.has_trans:
+            rec_embeddings = pooler_output
+            _, pooler_output = self.trans_module(rec_embeddings)
+            if self.has_gate:
+                gated_output = self.gated_combination(frozen_output, pooler_output)
+            else:
+                gated_output = pooler_output
+        else:
+            gated_output = pooler_output
+
+        domain_labels = torch.zeros(gated_output.shape[0]).long().to(device)
+        domain_feature = torch.nn.functional.one_hot(
+            domain_labels, num_classes=self.n_domain_classes).squeeze(1)
+
+        emotion_features = None
+        if self.emotion_labels is not None:
+            if isinstance(self.emotion_labels, list):
+                emotion_tensor = torch.tensor(self.emotion_labels, dtype=torch.float32)
+                emotion_features = emotion_tensor.repeat(gated_output.shape[0], 1)
+            else:
+                emotion_features = torch.nn.functional.one_hot(
+                    self.emotion_labels.long(), num_classes=self.emotion_dim
+                ).squeeze(1)
+
+        if emotion_features is not None:
+            emotion_features = emotion_features[:gated_output.shape[0], :]
+            class_output = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
+
+        else:
+            emotion_features = torch.zeros(gated_output.shape[0], self.emotion_dim).to(device)
+            class_output = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
+
+        class_output = torch.sigmoid(self.mf_classifier(class_output))
+        if return_dict:
+            mft_dimensions = [
+                'CH',
+                'FC',
+                'LB',
+                'AS',
+                'PD'
+            ]
+
+            result_list = []
+            for i in range(class_output.shape[0]):
+                row_scores = [round(score.item(), 5) for score in class_output[i]]
+                row_dict = dict(zip(mft_dimensions, row_scores))
+                result_list.append(row_dict)
+            return result_list
+
+        return class_output
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, **kwargs):
+        if config is None:
+            try:
+                config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            except (OSError, ValueError):
+                config = cls.config_class()
+
+        return super().from_pretrained(
+            pretrained_model_name_or_path,
+            *model_args,
+            config=config,
+            **kwargs,
+        )