#!/usr/bin/env python3
"""Remove one tagged helper definition from Python source files, in place.

Usage::

    python <this script> FILE [FILE ...]

A deletion starts at a line whose stripped text equals ``MATCH_PATTERN_1``
when the line right after it, stripped, equals ``MATCH_PATTERN_2``; both of
those lines are dropped. Everything after them is dropped too, until a blank
line is reached whose following line, stripped, equals
``END_MATCH_PATTERN_2``. That blank line is kept and deletion stops. With the
default ``END_MATCH_PATTERN_2 = ""`` the stop point is the first of two
consecutive blank lines. If no stop point is found, deletion runs to the end
of the file.

The blank lines that surrounded the removed definition are all kept, so the
result may contain more consecutive blank lines than PEP 8 allows; run the
project formatter afterwards.
"""
import sys
from typing import List, Optional, Sequence

# Stripped text of the line that tags the definition to remove.
MATCH_PATTERN_1 = "# Copied from transformers.models.bart.modeling_bart.prepare_4d_attention_mask"
# Stripped text that must appear on the line immediately after MATCH_PATTERN_1.
MATCH_PATTERN_2 = "def prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):"
# Deletion stops at a blank line whose *next* line, stripped, equals this.
END_MATCH_PATTERN_2 = ""

# Other presets this script has been used with (swap them in above as needed):
#   MATCH_PATTERN_1 = "# Copied from transformers.models.bart.modeling_bart._make_causal_mask"
#   MATCH_PATTERN_2 = "def _make_causal_mask("
#   END_MATCH_PATTERN_2 = ""
#
#   MATCH_PATTERN_1 = "def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):"
#   MATCH_PATTERN_2 = "# create causal mask"
#   END_MATCH_PATTERN_2 = "def forward("


def remove_marked_blocks(
    lines: Sequence[str],
    start_marker: str = MATCH_PATTERN_1,
    start_def: str = MATCH_PATTERN_2,
    end_next_line: str = END_MATCH_PATTERN_2,
) -> List[str]:
    """Return ``lines`` with every marked block removed.

    Args:
        lines: File contents as returned by ``readlines()`` (line endings kept).
        start_marker: Stripped text of the line that opens a block.
        start_def: Stripped text required on the line right after
            ``start_marker`` for the block to open.
        end_next_line: A blank line closes the block when the stripped text of
            the line after it equals this value; that blank line is kept.

    Returns:
        A new list; ``lines`` itself is not modified.
    """
    stripped = [line.strip() for line in lines]
    kept: List[str] = []
    deleting = False
    for i, line in enumerate(lines):
        # Look one line ahead; None on the last line so neither the start nor
        # the stop condition can fire there.
        following: Optional[str] = stripped[i + 1] if i + 1 < len(lines) else None
        if stripped[i] == start_marker and following == start_def:
            deleting = True
        elif stripped[i] == "" and following == end_next_line:
            deleting = False
        if not deleting:
            kept.append(line)
    return kept


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Rewrite each file named in ``argv`` (default: ``sys.argv[1:]``) in place."""
    filenames = sys.argv[1:] if argv is None else argv
    for filename in filenames:
        # Python sources are UTF-8 by default (PEP 3120). newline="" keeps each
        # line's original terminator (e.g. CRLF) on read and on write, so lines
        # that are not removed come back byte-identical.
        with open(filename, "r", encoding="utf-8", newline="") as f:
            lines = f.readlines()
        new_lines = remove_marked_blocks(lines)
        if new_lines == lines:
            continue  # nothing matched: leave the file (and its mtime) alone
        with open(filename, "w", encoding="utf-8", newline="") as f:
            f.writelines(new_lines)
        print(f"{filename}: removed {len(lines) - len(new_lines)} line(s)")


if __name__ == "__main__":
    main()