amaiasalvador committed
Commit · 3ab629c
Parent(s): 9f2a462

transformer no longer returns unnecessary attention weights. fix: allow backward when training ingredient decoder

- src/model.py +2 -2
- src/modules/transformer_decoder.py +8 -11
src/model.py CHANGED

@@ -211,7 +211,7 @@ class InverseCookingModel(nn.Module):
             ingr_ids[sample_mask == 0] = self.pad_value
 
             outputs['ingr_ids'] = ingr_ids
-            outputs['ingr_probs'] = ingr_probs
+            outputs['ingr_probs'] = ingr_probs.data
 
             mask = sample_mask
             input_mask = mask.float().unsqueeze(1)
@@ -230,7 +230,7 @@ class InverseCookingModel(nn.Module):
         ids, probs = self.recipe_decoder.sample(input_feats, input_mask, greedy, temperature, beam, img_features, 0,
                                                 last_token_value=1)
 
-        outputs['recipe_probs'] = probs
+        outputs['recipe_probs'] = probs.data
         outputs['recipe_ids'] = ids
 
         return outputs
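The change in src/model.py stores `ingr_probs.data` and `probs.data` instead of the live tensors. A minimal sketch of the effect (illustration only, not repo code, assuming a recent PyTorch): `.data` yields a view that shares storage with the original tensor but carries no autograd history, so the probabilities kept in `outputs` no longer retain the sampling graph; `.detach()` is the equivalent modern spelling.

import torch

# Illustration only (not repo code): `.data`, like `.detach()`, returns a
# tensor that shares storage with the original but has no autograd history,
# so storing it in a results dict does not keep the graph alive.
logits = torch.randn(2, 5, requires_grad=True)
probs = torch.softmax(logits, dim=-1)

print(probs.requires_grad)           # True: still attached to the graph
print(probs.data.requires_grad)      # False: detached view, as stored here
print(probs.detach().requires_grad)  # False: preferred modern equivalent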
src/modules/transformer_decoder.py CHANGED

@@ -161,12 +161,11 @@ class TransformerDecoderLayer(nn.Module):
             self.last_ln = LayerNorm(self.embed_dim)
 
     def forward(self, x, ingr_features, ingr_mask, incremental_state, img_features):
-        attn_dict = dict()
 
         # self attention
         residual = x
         x = self.maybe_layer_norm(0, x, before=True)
-        x,
+        x, _ = self.self_attn(
             query=x,
             key=x,
             value=x,
@@ -184,7 +183,7 @@ class TransformerDecoderLayer(nn.Module):
         # attention
         if ingr_features is None:
 
-            x,
+            x, _ = self.cond_att(query=x,
                                  key=img_features,
                                  value=img_features,
                                  key_padding_mask=None,
@@ -192,7 +191,7 @@ class TransformerDecoderLayer(nn.Module):
                                  static_kv=True,
                                  )
         elif img_features is None:
-            x,
+            x, _ = self.cond_att(query=x,
                                  key=ingr_features,
                                  value=ingr_features,
                                  key_padding_mask=ingr_mask,
@@ -206,7 +205,7 @@ class TransformerDecoderLayer(nn.Module):
             kv = torch.cat((img_features, ingr_features), 0)
             mask = torch.cat((torch.zeros(img_features.shape[1], img_features.shape[0], dtype=torch.uint8).to(device),
                               ingr_mask), 1)
-            x,
+            x, _ = self.cond_att(query=x,
                                  key=kv,
                                  value=kv,
                                  key_padding_mask=mask,
@@ -229,7 +228,7 @@ class TransformerDecoderLayer(nn.Module):
         if self.use_last_ln:
             x = self.last_ln(x)
 
-        return x
+        return x
 
     def maybe_layer_norm(self, i, x, before=False, after=False):
         assert before ^ after
@@ -308,16 +307,14 @@ class DecoderTransformer(nn.Module):
         x = x.transpose(0, 1)
 
         for p, layer in enumerate(self.layers):
-            x
+            x = layer(
                 x,
                 ingr_features,
                 ingr_mask,
                 incremental_state,
                 img_features
             )
-
-                attn_dict[key][p] = attn[key]
-            #attn_layers.append(attn)
+
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)
 
@@ -387,7 +384,7 @@ class DecoderTransformer(nn.Module):
             sampled_ids.append(predicted)
 
         sampled_ids = torch.stack(sampled_ids[1:], 1)
-        logits = torch.stack(logits, 1)
+        logits = torch.stack(logits, 1)
 
         return sampled_ids, logits
 
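For the transformer_decoder.py cleanup: attention modules of this fairseq style return an (output, weights) pair, and the layer now discards the weights with `_` and returns only the hidden states, which is why the decoding loop simplifies to `x = layer(...)`. A small stand-in sketch using torch.nn.MultiheadAttention (the repo ships its own attention module, so the module and arguments below are illustrative assumptions, not the project's API):

import torch
import torch.nn as nn

# Stand-in sketch, not repo code: attention returns (output, weights); with
# need_weights=False the weights come back as None and are discarded with `_`,
# mirroring the simplified decoder layer that now returns only `x`.
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4)
x = torch.randn(7, 2, 16)  # T x B x C, the layout used in the decoder
x_out, _ = attn(query=x, key=x, value=x, need_weights=False)
print(x_out.shape)  # torch.Size([7, 2, 16])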