# collate.py

import os
import torch
from copy import deepcopy
from argparse import ArgumentParser
from typing import List, Tuple, Dict
class Simple_Collator:
    """Collate raw sequences into model-ready tensors.

    Per sequence: tokenize -> generate (all-``ignore_label``) labels ->
    optionally pad to a multiple of ``max_len`` -> optionally truncate into
    ``max_len``-sized chunks. ``__call__`` stacks everything into torch
    tensors under the keys ``input_ids`` / ``labels`` / ``attention_mask``.
    """

    @staticmethod
    def add_args(parent_parser: ArgumentParser) -> ArgumentParser:
        """Register collator hyperparameters on *parent_parser* and return it."""
        parser = parent_parser.add_argument_group('Data Collator Config & Hyperparameter.')

        parser.add_argument('--max_len', default=256, type=int)        # max length of sequence
        parser.add_argument('--ignore_label', default=-100, type=int)  # pytorch standard ignore_label: -100
        parser.add_argument('--split_aa_num', default=3, type=int)     # new tokenizer split amino acid number

        # FIX: the original used type=bool, which maps ANY non-empty string
        # (including "False") to True; parse common false-y spellings instead.
        # Default is unchanged.
        parser.add_argument('--truncation', default=True,
                            type=lambda s: str(s).lower() not in ('false', '0', 'no'))
        parser.add_argument('--truncation_mode', default='cut', type=str, choices=['window', 'cut'])

        parser.add_argument('--padding', default=True)
        parser.add_argument('--padding_token', default='[PAD]', type=str)

        return parent_parser

    def __init__(self, tokenizer, args) -> None:
        """Store the tokenizer and the collation hyperparameters from *args*."""
        self.tokenizer = tokenizer  # get the tokenizer
        self.max_len = args.max_len
        self.ignore_label = args.ignore_label
        self.split_aa_num = args.split_aa_num

        # truncation, padding, mask
        assert args.truncation_mode in ['window', 'cut'], "truncate mode must be 'window' or 'cut'."
        self.trunc = args.truncation
        self.trunc_mode = args.truncation_mode

        self.padding = args.padding
        self.padding_token = args.padding_token

    def process_tokens(self, tokens_ids: List[int]) -> Tuple[List[int], List[int]]:
        """Return the token ids unchanged plus an all-``ignore_label`` label list."""
        tokens_labels = [self.ignore_label] * len(tokens_ids)
        return tokens_ids, tokens_labels

    def pad_tokens(self,
                   tokens_ids: List[int],
                   tokens_labels: List[int]) -> Tuple[List[int], List[int], List[int]]:
        """Pad ids/labels in place to a multiple of ``max_len``.

        Returns the (padded) ids, labels, and a 1/0 attention mask.

        FIX: the original computed ``max_len - (raw_len % max_len)``, which
        appends a full block of ``max_len`` pads when the sequence length is
        already aligned; ``(-raw_len) % max_len`` is 0 in that case.
        """
        raw_len = len(tokens_ids)

        len_diff = (-raw_len) % self.max_len
        if len_diff:
            # NOTE(review): assumes tokenizer.encode('[PAD]') yields a single
            # token id, not a list — confirm against the tokenizer in use.
            tokens_ids += [self.tokenizer.encode(self.padding_token)] * len_diff
            tokens_labels += [self.ignore_label] * len_diff
        tokens_attn_mask = [1] * raw_len + [0] * len_diff

        return tokens_ids, tokens_labels, tokens_attn_mask

    def trunc_tokens(self, data: list) -> List[list]:
        """Split *data* into ``max_len``-sized chunks.

        'window': every sliding window of length ``max_len`` (overlapping).
        'cut': consecutive non-overlapping chunks (last one may be shorter
        unless the data was padded first).
        """
        tokens_len = len(data)

        if tokens_len <= self.max_len:
            return [data]

        res = []
        if self.trunc_mode == 'window':
            for i in range(tokens_len - self.max_len + 1):
                res.append(deepcopy(data[i: i + self.max_len]))
        elif self.trunc_mode == 'cut':
            for i in range(0, tokens_len, self.max_len):
                res.append(deepcopy(data[i: i + self.max_len]))

        return res

    def seq2data(self, seq: str) -> Tuple[list, list, list]:
        """Turn one sequence into (ids, labels, attention_mask).

        Each element is a flat list, or a list of ``max_len`` chunks when
        truncation is enabled.
        """
        tokens_ids = self.tokenizer.tokenize(seq)  # 1. tokenize the sequence

        tokens_ids, tokens_labels = self.process_tokens(tokens_ids)  # 2. generate labels

        if self.padding is True:
            tokens_ids, tokens_labels, tokens_attn_mask = self.pad_tokens(tokens_ids, tokens_labels)  # 3. padding seqs
        else:
            # FIX: the original left tokens_attn_mask undefined when padding
            # was disabled, crashing with UnboundLocalError below.
            tokens_attn_mask = [1] * len(tokens_ids)

        if self.trunc is True:
            tokens_ids, tokens_labels, tokens_attn_mask = [self.trunc_tokens(i)
                                                           for i in [tokens_ids, tokens_labels, tokens_attn_mask]]  # 4. truncate data

        return tokens_ids, tokens_labels, tokens_attn_mask

    def __call__(self, data, HF_dataset: bool = False) -> Dict:
        """Collate *data* into a dict of tensors.

        *data* may be a single string, an iterable of strings, or (with
        ``HF_dataset=True``) an iterable of rows carrying a ``'seq'`` field.
        """
        input_ids, labels, attn_mask = [], [], []

        if HF_dataset is False:
            if isinstance(data, str): data = [data]  # process single protein sequence for testing

        for i in data:
            seq = i['seq'] if HF_dataset else i
            tokens_ids, tokens_labels, tokens_attn_mask = self.seq2data(seq)

            input_ids.extend(deepcopy(tokens_ids))
            labels.extend(deepcopy(tokens_labels))
            attn_mask.extend(deepcopy(tokens_attn_mask))

        return {
            'input_ids': torch.tensor(input_ids),
            'labels': torch.tensor(labels),
            'attention_mask': torch.tensor(attn_mask)}