yangheng committed
Commit 220d808 · verified · 1 Parent(s): b3f6910

Upload 6 files

config.json ADDED
@@ -0,0 +1,52 @@
+ {
+   "architectures": [
+     "BertForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "bos_token_id": 0,
+   "do_sample": false,
+   "eos_token_ids": 0,
+   "finetuning_task": null,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1"
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "is_decoder": false,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1
+   },
+   "layer_norm_eps": 1e-12,
+   "length_penalty": 1.0,
+   "max_length": 10,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_beams": 1,
+   "num_hidden_layers": 12,
+   "num_labels": 2,
+   "num_return_sequences": 1,
+   "num_rnn_layer": 1,
+   "output_attentions": false,
+   "output_hidden_states": false,
+   "output_past": true,
+   "pad_token_id": 0,
+   "pruned_heads": {},
+   "repetition_penalty": 1.0,
+   "rnn": "lstm",
+   "rnn_dropout": 0.0,
+   "rnn_hidden": 768,
+   "split": 10,
+   "temperature": 1.0,
+   "top_k": 50,
+   "top_p": 1.0,
+   "torchscript": false,
+   "type_vocab_size": 2,
+   "use_bfloat16": false,
+   "vocab_size": 69
+ }
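
This is a BERT-style masked-LM config with a 69-token vocabulary (5 special tokens plus the 64 RNA codons listed in vocab.txt below) and a few custom fields (rnn, rnn_hidden, num_rnn_layer, split) that stock BERT ignores. A minimal sketch of inspecting it with transformers.AutoConfig, assuming the files have been downloaded to a local directory; the "./" path is a placeholder, since the repo id does not appear in this diff:

# Sketch: inspect the uploaded config with transformers.
# "./" is a placeholder for a local directory containing config.json.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("./")
print(config.model_type)   # "bert"
print(config.vocab_size)   # 69 = 5 special tokens + 64 RNA codons
print(config.hidden_size)  # 768
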
omnigenome_wrapper.py ADDED
@@ -0,0 +1,108 @@
+ # -*- coding: utf-8 -*-
+ # file: omnigenome_wrapper.py
+ # time: 00:57 27/04/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+ 
+ import warnings
+ 
+ from transformers import AutoTokenizer
+ 
+ from omnigenbench import OmniKmersTokenizer
+ 
+ 
+ class Tokenizer(OmniKmersTokenizer):
+     def __init__(
+         self, base_tokenizer=None, k=3, overlap=0, max_length=512, t2u=True, **kwargs
+     ):
+         super(Tokenizer, self).__init__(base_tokenizer, t2u=t2u, **kwargs)
+         self.k = k
+         self.overlap = overlap
+         self.max_length = max_length
+         self.metadata["tokenizer_name"] = self.__class__.__name__
+ 
+     def __call__(self, sequence, **kwargs):
+         # Normalize the alphabet: U->T (DNA) or T->U (RNA), depending on the flags.
+         if self.u2t:
+             sequence = "".join([seq.replace("U", "T").upper() for seq in sequence])
+         if self.t2u:
+             sequence = "".join([seq.replace("T", "U").upper() for seq in sequence])
+ 
+         # Truncate each sequence's k-mers (not the list of sequences),
+         # reserving two slots for the BOS/EOS tokens.
+         max_length = kwargs.get("max_length") or self.max_length
+         sequence_tokens = [
+             tokens[: max_length - 2] for tokens in self.tokenize(sequence)
+         ]
+         tokenized_inputs = {
+             "input_ids": [],
+             "attention_mask": [],
+         }
+         bos_id = (
+             self.base_tokenizer.bos_token_id
+             if self.base_tokenizer.bos_token_id is not None
+             else self.base_tokenizer.cls_token_id
+         )
+         eos_id = (
+             self.base_tokenizer.eos_token_id
+             if self.base_tokenizer.eos_token_id is not None
+             else self.base_tokenizer.sep_token_id
+         )
+ 
+         for tokens in sequence_tokens:
+             tokenized_inputs["input_ids"].append(
+                 [bos_id] + self.base_tokenizer.convert_tokens_to_ids(tokens) + [eos_id]
+             )
+             tokenized_inputs["attention_mask"].append(
+                 [1] * len(tokenized_inputs["input_ids"][-1])
+             )
+ 
+         for i, ids in enumerate(tokenized_inputs["input_ids"]):
+             if ids.count(self.base_tokenizer.unk_token_id) / len(ids) > 0.1:
+                 warnings.warn(
+                     f"More than 10% of the tokens in sequence {i} are unknown; "
+                     "please check the tokenization process."
+                 )
+         tokenized_inputs = self.base_tokenizer.pad(
+             tokenized_inputs,
+             padding="max_length",
+             max_length=max_length,
+             pad_to_multiple_of=max_length,
+             return_attention_mask=True,
+             return_tensors="pt",
+         )
+         return tokenized_inputs
+ 
+     @staticmethod
+     def from_pretrained(model_name_or_path, **kwargs):
+         # Wrap the underlying Hugging Face tokenizer in this k-mer tokenizer.
+         return Tokenizer(
+             AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
+         )
+ 
+     def tokenize(self, sequence, **kwargs):
+         if isinstance(sequence, str):
+             sequences = [sequence]
+         else:
+             sequences = sequence
+ 
+         # Slide a window of size k over each sequence; the step is
+         # k - overlap, so overlap must be smaller than k.
+         sequence_tokens = []
+         for i in range(len(sequences)):
+             tokens = []
+             for j in range(0, len(sequences[i]), self.k - self.overlap):
+                 tokens.append(sequences[i][j : j + self.k])
+ 
+             sequence_tokens.append(tokens)
+ 
+         return sequence_tokens
+ 
+     def encode(self, input_ids, **kwargs):
+         return self.base_tokenizer.encode(input_ids, **kwargs)
+ 
+     def decode(self, input_ids, **kwargs):
+         return self.base_tokenizer.decode(input_ids, **kwargs)
+ 
+     def encode_plus(self, sequence, **kwargs):
+         raise NotImplementedError("The encode_plus() function is not implemented yet.")
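
A minimal usage sketch for the wrapper, assuming the omnigenbench package is installed and the repo files sit in a local directory; the "./" path, and the base-class details (u2t and metadata attributes), are assumptions rather than anything shown in this diff:

# Hypothetical usage: tokenize an RNA sequence into 3-mers (codons).
from omnigenome_wrapper import Tokenizer

tokenizer = Tokenizer.from_pretrained("./")  # defaults: k=3, overlap=0, max_length=512
batch = tokenizer("AUGGCCUAA")               # 3-mers: AUG, GCC, UAA
print(batch["input_ids"].shape)              # torch.Size([1, 512]) after padding
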
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7aca71823ab74771006be1030d9e7239220bba40a16858575929a02e6d2a7471
+ size 346827305
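
The diff stores only the Git LFS pointer; the actual ~347 MB checkpoint lives in LFS and is resolved at download time. A sketch of fetching it with huggingface_hub, where the repo id is a placeholder since it does not appear in this diff:

# Hypothetical download; "yangheng/<repo-id>" is a placeholder.
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(repo_id="yangheng/<repo-id>", filename="pytorch_model.bin")
print(weights_path)  # local cache path of the resolved 346,827,305-byte file
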
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"do_lower_case": false, "max_len": 512}
vocab.txt ADDED
@@ -0,0 +1,69 @@
+ [PAD]
+ [UNK]
+ [CLS]
+ [SEP]
+ [MASK]
+ AAA
+ AAU
+ AAC
+ AAG
+ AUA
+ AUU
+ AUC
+ AUG
+ ACA
+ ACU
+ ACC
+ ACG
+ AGA
+ AGU
+ AGC
+ AGG
+ UAA
+ UAU
+ UAC
+ UAG
+ UUA
+ UUU
+ UUC
+ UUG
+ UCA
+ UCU
+ UCC
+ UCG
+ UGA
+ UGU
+ UGC
+ UGG
+ CAA
+ CAU
+ CAC
+ CAG
+ CUA
+ CUU
+ CUC
+ CUG
+ CCA
+ CCU
+ CCC
+ CCG
+ CGA
+ CGU
+ CGC
+ CGG
+ GAA
+ GAU
+ GAC
+ GAG
+ GUA
+ GUU
+ GUC
+ GUG
+ GCA
+ GCU
+ GCC
+ GCG
+ GGA
+ GGU
+ GGC
+ GGG
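
The vocabulary is the 5 standard special tokens followed by all 64 RNA 3-mers over the A/U/C/G alphabet, 69 entries in total, matching vocab_size: 69 in config.json. A short check that reproduces the codon portion in the same base order:

# Reproduce the 64-codon portion of vocab.txt (base order A, U, C, G).
from itertools import product

codons = ["".join(p) for p in product("AUCG", repeat=3)]
assert len(codons) == 64
assert codons[0] == "AAA" and codons[-1] == "GGG"  # first and last codon entries above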