| |
| |
| |
|
|
| |
|
|
| |
| import torch |
|
|
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
def find_head_to_mask(heads_mask) -> int:
    """Return the index of the largest entry of ``heads_mask`` as a plain Python int."""
    return torch.argmax(heads_mask).item()
|
|
def commonsense_attention_mask_update(bsz, n_tokens, commonsense_matrix, attn_weights,
                                      num_heads=16, specific_head=0):
    """Re-weight one attention head by an element-wise commonsense matrix.

    The weights of ``specific_head`` are multiplied element-wise by
    ``commonsense_matrix``; every other head keeps its original weights.

    NOTE(review): the ``(num_heads, bsz, ...)`` reshape reinterprets memory
    rather than transposing axes, so for bsz > 1 the selected "head" slice may
    not correspond to a real attention head — confirm the intended layout
    with callers.

    Args:
        bsz: batch size.
        n_tokens: sequence length (attention maps are n_tokens x n_tokens).
        commonsense_matrix: (bsz, n_tokens, n_tokens) multiplier, or None for
            all-ones (neutral, since the mask is multiplicative).
        attn_weights: tensor with bsz * num_heads * n_tokens * n_tokens
            elements. MUTATED in place: the selected head slice is zeroed
            (the returned tensor carries the re-weighted values).
        num_heads: number of attention heads.
        specific_head: index of the head to re-weight.

    Returns:
        Tensor of shape (bsz, num_heads, n_tokens, n_tokens) on the same
        device as ``attn_weights``.
    """
    device = attn_weights.device
    commonsense_mask = torch.zeros((bsz, num_heads, n_tokens, n_tokens), device=device)
    attn_weights_helper = attn_weights.reshape((num_heads, bsz, n_tokens, n_tokens))
    # BUG FIX: clone before zeroing. Indexing yields a view, so without the
    # clone the zero-assignment below also wiped the saved weights and the
    # commonsense product further down was always zero.
    head_previous_attention_weights = attn_weights_helper[specific_head].clone()
    attn_weights_helper[specific_head] = torch.zeros((bsz, n_tokens, n_tokens), device=device)
    attn_weights_helper = attn_weights_helper.reshape((bsz, num_heads, n_tokens, n_tokens))
    if commonsense_matrix is None:
        # Not passed: ones are neutral because the mask is multiplicative.
        commonsense_matrix = torch.ones((bsz, n_tokens, n_tokens), device=device)
    commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
    commonsense_mask[specific_head] = head_previous_attention_weights * commonsense_matrix
    # BUG FIX: follow the input's device instead of hard-coding 'cuda', which
    # crashed on CPU-only runs and mismatched non-default CUDA devices.
    commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens)).to(device)
    return attn_weights_helper + commonsense_mask
|
|
def convert_relations_to_binary_mask(input_relations, should_clone=True):
    """Binarize a relation tensor: entries greater than 1 become 1.

    With ``should_clone=True`` (default) the input is left untouched and a
    modified copy is returned; with ``should_clone=False`` the input tensor
    itself is mutated and returned.
    """
    binary_mask = input_relations.clone() if should_clone else input_relations
    binary_mask[binary_mask > 1] = 1
    return binary_mask
|
|
def relation_binary_2d_to_1d(relations_binary_mask):
    """Collapse a mask along dim 1: sum the rows, then cap values above 1 at 1.

    Values of 1 or below (including any negatives) pass through unchanged.
    """
    row_totals = relations_binary_mask.sum(dim=1)
    return row_totals.clamp_(max=1)
|
|
def create_layer_with_commonsense_on_specific_head(relation_binary_mask, bsz, num_heads, specific_head=0):
    """Build a per-layer head mask whose ``specific_head`` slot holds the relation mask.

    All other head slots are zero. NOTE(review): the final reshape from
    (num_heads, bsz, ...) to (bsz, num_heads, ...) reinterprets memory rather
    than transposing, so for bsz > 1 the head placement may not line up with
    the batch-major layout — confirm with callers.
    """
    n_tokens = relation_binary_mask.size()[-1]
    # Allocate directly in head-major layout, fill the chosen slot, then
    # reinterpret as batch-major — identical memory to the zeros/reshape dance.
    layer = torch.zeros((num_heads, bsz, n_tokens, n_tokens))
    layer[specific_head] = relation_binary_mask
    return layer.reshape((bsz, num_heads, n_tokens, n_tokens))
|
|
def update_weights_regarding_relations_on_specific_head(layer_head_mask, attn_weights, relation_inputs, bsz, num_heads, tgt_len, src_len, verbose=True):
    """Blend attention weights with relation information on the masked heads.

    Heads where ``layer_head_mask`` is 1 have their weights multiplied by the
    binarized ``relation_inputs``; heads where it is 0 keep their original
    weights. NOTE(review): ``relation_inputs`` is binarized in place
    (``should_clone=False``), so the caller's tensor is mutated — confirm
    that is intended.

    Returns:
        Tensor of shape (bsz, num_heads, tgt_len, src_len).
    """
    mask = layer_head_mask.view(num_heads, 1, 1)
    # (m - 1) * -1 == 1 - m: selects the heads NOT chosen for relation weighting.
    inverse_mask = 1 - mask
    weights_4d = attn_weights.view(bsz, num_heads, tgt_len, src_len)

    if verbose:
        print("==============================")
        print('layer_head_mask.shape:', layer_head_mask.shape)
        print('inverse_layer_head_mask.shape:', inverse_mask.shape)
        print('attn_weights.shape:', attn_weights.shape)
        print('relation_inputs.shape', relation_inputs.shape)
        print("==============================")

    untouched_heads = inverse_mask * weights_4d
    relation_inputs = convert_relations_to_binary_mask(relation_inputs, should_clone=False)
    relation_heads = mask * relation_inputs.view(bsz, 1, tgt_len, src_len) * weights_4d
    blended = untouched_heads + relation_heads

    if verbose:
        print('attn_weights_int.shape', blended.shape)
    return blended
|
|
| """ |
| def create_commonsense_mask(self, bsz, n_tokens, commonsense_matrix, num_heads=16, specific_head=0): |
| commonsense_mask = torch.zeros( |
| ((bsz, num_heads, n_tokens, n_tokens)) |
| ) |
| if commonsense_matrix is None: |
| commonsense_matrix = torch.zeros( |
| ((bsz, n_tokens, n_tokens)) |
| ) |
| commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens)) |
| commonsense_mask[specific_head] = commonsense_matrix |
| commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens)) |
| return commonsense_mask |
| |
| def commonsense_attention_mask_update(self, bsz, n_tokens, commonsense_matrix, attn_weights, |
| specific_head=0): |
| num_heads = self.num_heads |
| commonsense_mask = torch.zeros( |
| ((bsz, num_heads, n_tokens, n_tokens)) |
| ) |
| attn_weights_helper = attn_weights.reshape((num_heads, bsz, n_tokens, n_tokens)) |
| zeros = torch.zeros( |
| ((bsz, n_tokens, n_tokens)) |
| ) |
| head_previous_attention_weights = attn_weights_helper[specific_head] |
| attn_weights_helper[specific_head] = zeros |
| attn_weights_helper = attn_weights_helper.reshape((bsz, num_heads, n_tokens, n_tokens)) |
| if commonsense_matrix is None: |
| # ignore is not passed (ones -> neutral since multiplication is used) |
| commonsense_matrix = torch.ones( |
| ((bsz, n_tokens, n_tokens)) |
| ) |
| commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens)) |
| commonsense_mask[specific_head] = head_previous_attention_weights * commonsense_matrix |
| # TODO Stupid conversion |
| commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens)).to('cuda') |
| return attn_weights_helper + commonsense_mask |
| """ |