Update modeling_me2bert.py
Browse files- modeling_me2bert.py +2 -2
modeling_me2bert.py
CHANGED
|
@@ -187,11 +187,11 @@ class ME2BertModel(PreTrainedModel):
|
|
| 187 |
|
| 188 |
if emotion_features is not None:
|
| 189 |
emotion_features = emotion_features[:gated_output.shape[0], :]
|
| 190 |
-
class_output = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
|
| 191 |
|
| 192 |
else:
|
| 193 |
emotion_features = torch.zeros(gated_output.shape[0], self.emotion_dim).to(device)
|
| 194 |
-
class_output = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
|
| 195 |
|
| 196 |
class_output = torch.sigmoid(self.mf_classifier(class_output))
|
| 197 |
if return_dict:
|
|
|
|
| 187 |
|
| 188 |
if emotion_features is not None:
|
| 189 |
emotion_features = emotion_features[:gated_output.shape[0], :]
|
| 190 |
+
class_output = torch.cat((gated_output.to(device), domain_feature.to(device), emotion_features.to(device)), dim=1)
|
| 191 |
|
| 192 |
else:
|
| 193 |
emotion_features = torch.zeros(gated_output.shape[0], self.emotion_dim).to(device)
|
| 194 |
+
class_output = torch.cat((gated_output.to(device), domain_feature.to(device), emotion_features.to(device)), dim=1)
|
| 195 |
|
| 196 |
class_output = torch.sigmoid(self.mf_classifier(class_output))
|
| 197 |
if return_dict:
|