didididadada committed
Commit c196cad · verified · 1 Parent(s): 4688646

Delete collate.py

Files changed (1):
collate.py +0 -105
collate.py DELETED
@@ -1,105 +0,0 @@
- import torch
- from copy import deepcopy
- from argparse import ArgumentParser
- from typing import List, Tuple, Dict
-
-
- class Simple_Collator:
-
-     @staticmethod
-     def add_args(parent_parser: ArgumentParser) -> ArgumentParser:
-         parser = parent_parser.add_argument_group('Data Collator Config & Hyperparameters')
-
-         parser.add_argument('--max_len', default=256, type=int)        # max length of a sequence chunk
-         parser.add_argument('--ignore_label', default=-100, type=int)  # PyTorch's standard ignore index: -100
-         parser.add_argument('--split_aa_num', default=3, type=int)     # amino acids per token for the new tokenizer
-
-         # note: argparse's type=bool treats any non-empty string as True
-         parser.add_argument('--truncation', default=True, type=bool)
-         parser.add_argument('--truncation_mode', default='cut', type=str, choices=['window', 'cut'])
-
-         parser.add_argument('--padding', default=True)
-         parser.add_argument('--padding_token', default='[PAD]', type=str)
-
-         return parent_parser
-
-     def __init__(self, tokenizer, args) -> None:
-
-         self.tokenizer = tokenizer
-         self.max_len = args.max_len
-         self.ignore_label = args.ignore_label
-         self.split_aa_num = args.split_aa_num
-
-         # truncation and padding config
-         assert args.truncation_mode in ['window', 'cut'], "truncation_mode must be 'window' or 'cut'."
-         self.trunc = args.truncation
-         self.trunc_mode = args.truncation_mode
-
-         self.padding = args.padding
-         self.padding_token = args.padding_token
-
-     def process_tokens(self, tokens_ids: List[int]) -> Tuple[List[int], List[int]]:
-         # placeholder for masking logic: every position gets the ignore label
-         tokens_labels = [self.ignore_label] * len(tokens_ids)
-         return tokens_ids, tokens_labels
-
-     def pad_tokens(self,
-                    tokens_ids: List[int],
-                    tokens_labels: List[int]) -> Tuple[List[int], List[int], List[int]]:
-
-         raw_len = len(tokens_ids)
-
-         # pad up to the next multiple of max_len (no-op when already a multiple);
-         # assumes tokenizer.encode() returns a single id for the padding token
-         len_diff = (self.max_len - raw_len % self.max_len) % self.max_len
-         tokens_ids += [self.tokenizer.encode(self.padding_token)] * len_diff
-         tokens_labels += [self.ignore_label] * len_diff
-         tokens_attn_mask = [1] * raw_len + [0] * len_diff
-
-         return tokens_ids, tokens_labels, tokens_attn_mask
-
-     def trunc_tokens(self, data: list) -> List[list]:
-
-         res = []
-         tokens_len = len(data)
-
-         if tokens_len <= self.max_len: return [data]
-
-         if self.trunc_mode == 'window':
-             # sliding window: all max_len-sized slices with stride 1
-             for i in range(tokens_len - self.max_len + 1):
-                 res.append(deepcopy(data[i: i + self.max_len]))
-         elif self.trunc_mode == 'cut':
-             # non-overlapping chunks of max_len
-             for i in range(0, tokens_len, self.max_len):
-                 res.append(deepcopy(data[i: i + self.max_len]))
-
-         return res
-
-     def seq2data(self, seq: str) -> Tuple[List[int], List[int], List[int]]:
-         tokens_ids = self.tokenizer.tokenize(seq)                    # 1. tokenize the sequence (assumed to return ids)
-
-         tokens_ids, tokens_labels = self.process_tokens(tokens_ids)  # 2. generate the per-token labels
-         tokens_attn_mask = [1] * len(tokens_ids)                     # default mask so padding can be disabled safely
-
-         if self.padding is True:
-             tokens_ids, tokens_labels, tokens_attn_mask = self.pad_tokens(tokens_ids, tokens_labels)  # 3. pad the sequence
-
-         if self.trunc is True:
-             tokens_ids, tokens_labels, tokens_attn_mask = [self.trunc_tokens(i)
-                                                            for i in [tokens_ids, tokens_labels, tokens_attn_mask]]  # 4. truncate into chunks
-
-         return tokens_ids, tokens_labels, tokens_attn_mask
-
-     def __call__(self, data, HF_dataset: bool = False) -> Dict:
-
-         input_ids, labels, attn_mask = [], [], []
-
-         if HF_dataset is False:
-             if isinstance(data, str): data = [data]  # accept a single protein sequence for testing
-
-         for i in data:
-             seq = i['seq'] if HF_dataset else i
-             tokens_ids, tokens_labels, tokens_attn_mask = self.seq2data(seq)
-
-             input_ids.extend(deepcopy(tokens_ids))
-             labels.extend(deepcopy(tokens_labels))
-             attn_mask.extend(deepcopy(tokens_attn_mask))
-
-         return {
-             'input_ids': torch.tensor(input_ids),
-             'labels': torch.tensor(labels),
-             'attention_mask': torch.tensor(attn_mask)}
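
For reference, a minimal usage sketch of the collator as it existed before this deletion. ToyTokenizer is a hypothetical stand-in for whatever tokenizer the project actually used; the collator only requires tokenize() to return a list of ids and encode() to return a single pad id, which is exactly what the stub assumes. Simple_Collator (and its torch import) is assumed to be in scope.

    from argparse import ArgumentParser

    class ToyTokenizer:
        """Hypothetical stand-in: maps each amino-acid character to an integer id."""
        def __init__(self):
            self.vocab = {aa: i + 1 for i, aa in enumerate('ACDEFGHIKLMNPQRSTVWY')}
            self.vocab['[PAD]'] = 0

        def tokenize(self, seq):
            return [self.vocab[aa] for aa in seq]   # one id per residue

        def encode(self, token):
            return self.vocab[token]                # single id, as pad_tokens expects

    parser = Simple_Collator.add_args(ArgumentParser())
    args = parser.parse_args(['--max_len', '8'])    # small max_len to force chunking

    collator = Simple_Collator(ToyTokenizer(), args)
    batch = collator(['ACDEFGHIKLMNPQRSTVWY', 'ACD'])   # a 20-mer and a 3-mer

    print(batch['input_ids'].shape)             # torch.Size([4, 8]): 3 'cut' chunks + 1
    print(int(batch['attention_mask'].sum()))   # 23 real positions, the rest is padding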
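
A note on the two truncation modes: 'cut' produces ceil(len / max_len) non-overlapping chunks, while 'window' produces len - max_len + 1 overlapping views with stride 1, so a single long protein can inflate into hundreds of training rows under 'window'. Because padding first rounds every sequence up to a multiple of max_len, each chunk comes out exactly max_len long and the final torch.tensor calls yield rectangular 2D batches.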