Taykhoom committed on
Commit
6509a75
·
verified ·
1 Parent(s): 1722240

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. tokenization_codonbert.py +91 -0
  2. tokenizer_config.json +4 -1
tokenization_codonbert.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import numpy as np
3
+ from transformers import BertTokenizer
4
+
5
+
6
class CodonBertTokenizer(BertTokenizer):
    """BertTokenizer that auto-converts nucleotide sequences to codon-level tokens.

    Raw nucleotide input is normalized (whitespace stripped, uppercased, T->U),
    then split into non-overlapping 3-mer codons before vocab lookup. Trailing
    1-2 nucleotides that do not form a complete codon are dropped.

    eos_token is aliased to sep_token ("[SEP]" unless a different sep_token is
    passed by keyword) so that pooling code that excludes both CLS and EOS/SEP
    positions works correctly.

    Standard usage (raw nucleotides):
        tokenizer("AUGAAAGGG")
        tokenizer(["AUGAAAGGG", "AUGUUUCCC"], return_tensors="pt", padding=True)

    CDS-aware usage (full mRNA + CDS track -> extract CDS, chunk, encode):
        enc, chunk_counts = tokenizer.batch_encode_with_cds(
            ["NNNATGAAAGGGNN"],
            cds_tracks=[np.array([0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0])],
            return_tensors="pt",
            padding=True,
        )

    Works with compare_minimal_vs_mm.py --use_cds out of the box.
    """

    def __init__(self, *args, **kwargs):
        """Build the tokenizer, defaulting eos_token to the configured sep_token.

        All positional and keyword arguments are forwarded to BertTokenizer.
        An explicitly supplied eos_token is left untouched.
        """
        # Follow a custom sep_token when one is given by keyword; otherwise
        # fall back to the stock BERT separator.
        kwargs.setdefault("eos_token", kwargs.get("sep_token", "[SEP]"))
        super().__init__(*args, **kwargs)

    def _tokenize(self, text, split_special_tokens=False):
        """Split a nucleotide string into codon tokens.

        Args:
            text: Raw nucleotide text; any whitespace is removed, letters are
                uppercased and T is rewritten to U.
            split_special_tokens: Accepted for signature compatibility with
                the base class; not used here.

        Returns:
            List of 3-character codon strings. A trailing partial codon
            (1-2 nucleotides) is dropped.
        """
        seq = "".join(text.split()).upper().replace("T", "U")
        usable = len(seq) - len(seq) % 3
        return [seq[i:i + 3] for i in range(0, usable, 3)]

    @staticmethod
    def _extract_cds(sequence, cds):
        """Return the CDS sub-string of ``sequence`` described by ``cds``.

        Args:
            sequence: Nucleotide string.
            cds: Array-like track aligned with ``sequence``; non-zero entries
                mark the first nucleotide of each codon.

        Returns:
            The slice from the first marked position through the end of the
            last marked codon, trimmed to a multiple of 3. If no position is
            marked, the whole sequence trimmed to a multiple of 3.

        Warns:
            UserWarning: when no CDS is marked, or when the extracted region's
                length is not a multiple of 3 (it is then truncated).
        """
        # Indices of every codon start; works for lists, bool and int arrays,
        # and treats any non-zero marker (not only 1) as a codon start.
        starts = np.flatnonzero(np.asarray(cds))
        if starts.size == 0:
            warnings.warn("No CDS found. Returning truncated sequence.")
            total = len(sequence)
            return sequence[:total - total % 3]

        first = int(starts[0])
        # The last marker is a codon *start*, so the CDS ends 3 nt after it.
        stop = int(starts[-1]) + 3
        proposed = sequence[first:stop]

        remainder = len(proposed) % 3
        if remainder:
            warnings.warn("Irregular CDS. Returning truncated sequence.")
            return proposed[:len(proposed) - remainder]
        return proposed

    def batch_encode_with_cds(self, sequences, cds_tracks, max_length=None, **kwargs):
        """Encode a batch of raw mRNA sequences using CDS-aware preprocessing.

        Each sequence is normalized (uppercase, T->U), reduced to its CDS via
        the matching track, and cut into chunks of at most ``max_length``
        codons. All chunks are then encoded in one batch_encode_plus call.

        Args:
            sequences: List of raw nucleotide strings.
            cds_tracks: List of array-likes (one per sequence). Non-zero values
                mark the first nucleotide of each codon in the CDS region.
            max_length: Max content codon-tokens per chunk (special tokens NOT
                counted). Defaults to model_max_length - 2 when None or 0.
                This matches the convention in compare_minimal_vs_mm.py where
                max_length is already adjusted for special tokens.
            **kwargs: Forwarded to batch_encode_plus (e.g. return_tensors, padding).

        Returns:
            (BatchEncoding, chunk_counts): chunk_counts[i] is the number of
            chunks produced from sequence i. A sequence with an empty CDS
            contributes a single empty-string chunk so that counts stay
            aligned with the inputs.

        Raises:
            ValueError: if ``sequences`` and ``cds_tracks`` differ in length,
                or if the effective per-chunk budget is smaller than 1 codon.
        """
        sequences = list(sequences)
        cds_tracks = list(cds_tracks)
        if len(sequences) != len(cds_tracks):
            # zip() would silently drop the surplus and misalign chunk_counts.
            raise ValueError(
                f"Got {len(sequences)} sequences but {len(cds_tracks)} CDS tracks; "
                "exactly one track per sequence is required."
            )

        budget_codons = max_length or (self.model_max_length - 2)
        if budget_codons < 1:
            # A non-positive step would yield no chunks and encode nothing.
            raise ValueError(
                f"max_length must allow at least one codon per chunk, got {budget_codons}."
            )
        budget_nt = budget_codons * 3

        all_strings = []
        chunk_counts = []

        for seq, cds in zip(sequences, cds_tracks):
            seq = seq.upper().replace("T", "U")
            cds_seq = self._extract_cds(seq, np.asarray(cds))
            # _extract_cds returns a length divisible by 3 and budget_nt is a
            # multiple of 3, so every slice below holds only whole codons.
            chunks = [
                cds_seq[start:start + budget_nt]
                for start in range(0, len(cds_seq), budget_nt)
            ]
            if not chunks:
                # Keep one (empty) entry so outputs stay aligned with inputs.
                chunks = [""]
            all_strings.extend(chunks)
            chunk_counts.append(len(chunks))

        enc = self.batch_encode_plus(all_strings, **kwargs)
        return enc, chunk_counts
tokenizer_config.json CHANGED
@@ -50,6 +50,9 @@
50
  "sep_token": "[SEP]",
51
  "strip_accents": null,
52
  "tokenize_chinese_chars": false,
53
- "tokenizer_class": "BertTokenizer",
 
 
 
54
  "unk_token": "[UNK]"
55
  }
 
50
  "sep_token": "[SEP]",
51
  "strip_accents": null,
52
  "tokenize_chinese_chars": false,
53
+ "tokenizer_class": "CodonBertTokenizer",
54
+ "auto_map": {
55
+ "AutoTokenizer": ["tokenization_codonbert.CodonBertTokenizer", null]
56
+ },
57
  "unk_token": "[UNK]"
58
  }