Upload model.py
Browse files
model.py
CHANGED
|
@@ -66,7 +66,7 @@ class ML_BART(nn.Module):
|
|
| 66 |
super().__init__()
|
| 67 |
d_model = bartconfig.d_model
|
| 68 |
|
| 69 |
-
self.
|
| 70 |
nn.Embedding(class_num[0] + 1, d_model // 4),
|
| 71 |
nn.Embedding(class_num[1] + 1, d_model // 4)
|
| 72 |
])
|
|
@@ -91,7 +91,7 @@ class ML_BART(nn.Module):
|
|
| 91 |
emb_decoder = self.encoder(x_decoder)
|
| 92 |
else:
|
| 93 |
emb_decoder = torch.concatenate(
|
| 94 |
-
[self.
|
| 95 |
self.decoder(x_encoder)], dim=-1)
|
| 96 |
|
| 97 |
y = self.bart(inputs_embeds=emb_encoder, decoder_inputs_embeds=emb_decoder,
|
|
|
|
| 66 |
super().__init__()
|
| 67 |
d_model = bartconfig.d_model
|
| 68 |
|
| 69 |
+
self.decoder_emb2 = nn.ModuleList([
|
| 70 |
nn.Embedding(class_num[0] + 1, d_model // 4),
|
| 71 |
nn.Embedding(class_num[1] + 1, d_model // 4)
|
| 72 |
])
|
|
|
|
| 91 |
emb_decoder = self.encoder(x_decoder)
|
| 92 |
else:
|
| 93 |
emb_decoder = torch.concatenate(
|
| 94 |
+
[self.decoder_emb2[0](x_decoder[..., 0]), self.decoder_emb2[1](x_decoder[..., 1]),
|
| 95 |
self.decoder(x_encoder)], dim=-1)
|
| 96 |
|
| 97 |
y = self.bart(inputs_embeds=emb_encoder, decoder_inputs_embeds=emb_decoder,
|