|
|
import os |
|
|
import torch |
|
|
from copy import deepcopy |
|
|
from argparse import ArgumentParser |
|
|
from typing import List, Tuple, Dict |
|
|
|
|
|
class Simple_Collator:
    """Collate raw sequences into fixed-length tensors for LM-style training.

    Pipeline per sequence: tokenize -> build ignore-labels -> (optionally)
    right-pad to a multiple of ``max_len`` -> (optionally) split into
    ``max_len``-sized chunks -> stack everything into ``torch.tensor``s.
    """

    @staticmethod
    def _str2bool(value) -> bool:
        """Parse a command-line boolean flag value.

        ``type=bool`` in argparse treats every non-empty string (including
        "False") as True; this parses the text explicitly instead.
        """
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() in ('1', 'true', 'yes', 'y', 't')

    @staticmethod
    def add_args(parent_parser: ArgumentParser) -> ArgumentParser:
        """Register the collator's config/hyperparameters on *parent_parser*.

        Returns the same parser so calls can be chained.
        """
        parser = parent_parser.add_argument_group('Data Collator Config & Hyperparameter.')

        parser.add_argument('--max_len', default = 256, type = int)
        # Label value ignored by the loss (PyTorch CrossEntropyLoss convention).
        parser.add_argument('--ignore_label', default = -100, type = int)
        # NOTE(review): stored but unused inside this class — presumably
        # consumed elsewhere; confirm before removing.
        parser.add_argument('--split_aa_num', default = 3, type = int)

        # Fix: was `type=bool`, which made `--truncation False` evaluate True.
        parser.add_argument('--truncation', default = True, type = Simple_Collator._str2bool)
        parser.add_argument('--truncation_mode', default = 'cut', type = str, choices=['window', 'cut'])

        # Fix: previously had no `type`, so a CLI-supplied value stayed a str.
        parser.add_argument('--padding', default = True, type = Simple_Collator._str2bool)
        parser.add_argument('--padding_token', default = '[PAD]', type = str)

        return parent_parser

    def __init__(self, tokenizer, args) -> None:
        """Store the tokenizer and copy the relevant fields out of *args*.

        *tokenizer* must provide ``tokenize(seq)`` and ``encode(token)``.
        """
        self.tokenizer = tokenizer
        self.max_len = args.max_len
        self.ignore_label = args.ignore_label
        self.split_aa_num = args.split_aa_num

        assert args.truncation_mode in ['window', 'cut'], "truncate mode must be 'window' or 'cut'."
        self.trunc = args.truncation
        self.trunc_mode = args.truncation_mode

        self.padding = args.padding
        self.padding_token = args.padding_token

    def process_tokens(self, tokens_ids: List[int]) -> Tuple[List[int], List[int]]:
        """Build a label list for *tokens_ids*: every position is masked
        with ``ignore_label`` (no supervised targets are produced here)."""
        tokens_labels = [self.ignore_label] * len(tokens_ids)
        return tokens_ids, tokens_labels

    def pad_tokens(self,
                   tokens_ids: List[int],
                   tokens_labels: List[int]) -> Tuple[List[int], List[int], List[int]]:
        """Right-pad ids and labels up to the next multiple of ``max_len``.

        Returns ``(tokens_ids, tokens_labels, attention_mask)`` where the
        mask is 1 over real tokens and 0 over padding.  The input lists are
        extended in place.
        """
        raw_len = len(tokens_ids)

        # Fix: `-raw_len % max_len` is 0 when raw_len is already a multiple
        # of max_len; the previous formula appended a spurious all-padding
        # chunk of length max_len in that case.
        len_diff = -raw_len % self.max_len

        # NOTE(review): assumes encode() returns a single id for the pad
        # token; if it returns a list, the ids become nested — confirm
        # against the tokenizer's API.
        tokens_ids += [self.tokenizer.encode(self.padding_token)] * len_diff
        tokens_labels += [self.ignore_label] * len_diff
        tokens_attn_mask = [1] * raw_len + [0] * len_diff

        return tokens_ids, tokens_labels, tokens_attn_mask

    def trunc_tokens(self, data: list) -> List[list]:
        """Split *data* into chunks of at most ``max_len`` items.

        'window' mode emits every sliding window of exactly ``max_len``;
        'cut' mode emits consecutive non-overlapping chunks.  A sequence
        that already fits is returned as a single chunk.
        """
        tokens_len = len(data)
        if tokens_len <= self.max_len:
            return [data]

        res = []
        if self.trunc_mode == 'window':
            for i in range(tokens_len - self.max_len + 1):
                # Slicing a list of ints already copies; deepcopy was redundant.
                res.append(data[i: i + self.max_len])
        elif self.trunc_mode == 'cut':
            for i in range(0, tokens_len, self.max_len):
                res.append(data[i: i + self.max_len])

        return res

    def seq2data(self, seq: str) -> Tuple[List[int], List[int], List[int]]:
        """Convert one raw sequence into (ids, labels, attention_mask).

        With truncation enabled each of the three results is a list of
        equal-length chunks; otherwise each is a flat list.
        """
        # NOTE(review): tokenize() is assumed to return token *ids* here
        # (it is fed straight into torch.tensor) — confirm.
        tokens_ids = self.tokenizer.tokenize(seq)

        tokens_ids, tokens_labels = self.process_tokens(tokens_ids)

        if self.padding:
            tokens_ids, tokens_labels, tokens_attn_mask = self.pad_tokens(tokens_ids, tokens_labels)
        else:
            # Fix: tokens_attn_mask was previously undefined (NameError)
            # whenever padding was disabled.
            tokens_attn_mask = [1] * len(tokens_ids)

        if self.trunc:
            tokens_ids, tokens_labels, tokens_attn_mask = [self.trunc_tokens(part)
                                                           for part in (tokens_ids, tokens_labels, tokens_attn_mask)]

        return tokens_ids, tokens_labels, tokens_attn_mask

    def __call__(self, data, HF_dataset: bool = False) -> Dict:
        """Collate *data* into a dict of tensors.

        *data* is a single string, an iterable of strings, or — when
        *HF_dataset* is True — an iterable of dicts with a 'seq' key.
        Returns {'input_ids', 'labels', 'attention_mask'} tensors.
        """
        input_ids, labels, attn_mask = [], [], []

        if HF_dataset is False:
            if isinstance(data, str):
                data = [data]

        for item in data:
            seq = item['seq'] if HF_dataset else item
            tokens_ids, tokens_labels, tokens_attn_mask = self.seq2data(seq)

            # seq2data builds fresh lists per sequence; deepcopy was redundant.
            input_ids.extend(tokens_ids)
            labels.extend(tokens_labels)
            attn_mask.extend(tokens_attn_mask)

        return {
            'input_ids': torch.tensor(input_ids),
            'labels': torch.tensor(labels),
            'attention_mask': torch.tensor(attn_mask)}