mRNABERT / tokenization_mrnabert.py
Taykhoom's picture
Upload tokenization_mrnabert.py with huggingface_hub
43ed474 verified
Raw
History Blame Contribute Delete
6.25 kB
import numpy as np
from transformers import BertTokenizer
class MRNABERTTokenizer(BertTokenizer):
    """BertTokenizer extended with CDS-aware preprocessing for mRNABERT.

    mRNABERT expects space-separated tokens where UTR regions use single
    nucleotides and CDS regions use three-letter codons. This tokenizer adds
    batch_encode_with_cds() to handle that preprocessing automatically.

    Standard usage (pre-formatted strings) still works as before:

        tokenizer(["A T C G ATG TTT CCC"], return_tensors="pt")

    CDS-aware usage (raw sequences + CDS track); the call below produces the
    same token string as the pre-formatted example above:

        encodings, chunk_counts = tokenizer.batch_encode_with_cds(
            ["ATCGATGTTTCCC"],
            cds=[np.array([0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0])],
            return_tensors="pt",
        )
    """

    # Bracketed tokens that must be kept atomic while splitting a raw string.
    _SPECIAL_TOKENS = frozenset({"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"})

    @classmethod
    def _parse_tokens(cls, sequence):
        """Split a sequence string into per-position tokens.

        Bracket-enclosed special tokens ([CLS], [MASK], etc.) are treated as
        atomic units so the token list length matches the CDS array length.
        Every other character, including a "[" that does not open a known
        special token, becomes its own single-character token.

        Args:
            sequence: Raw sequence string.

        Returns:
            List of string tokens, one per sequence position.
        """
        tokens = []
        i = 0
        n = len(sequence)
        while i < n:
            if sequence[i] == "[":
                close = sequence.find("]", i + 1)
                if close != -1:
                    candidate = sequence[i:close + 1]
                    if candidate in cls._SPECIAL_TOKENS:
                        tokens.append(candidate)
                        i = close + 1
                        continue
            # Ordinary character (or an unmatched/unknown bracket).
            tokens.append(sequence[i])
            i += 1
        return tokens

    @classmethod
    def preprocess(cls, sequence, cds=None):
        """Convert a raw sequence and optional CDS track to a tokenizer string.

        The CDS region is taken to span from the first non-zero entry of
        ``cds`` to three positions past the last non-zero entry, and is cut
        into consecutive triplets starting at the first non-zero entry.
        Triplets that are incomplete or contain a special token are emitted
        as individual tokens instead of being merged.

        Args:
            sequence: Raw nucleotide string, e.g. "ATCGATGTTTCCC".
                Uses T (not U).
            cds: Integer array (or list) with one entry per sequence
                position. Non-zero values mark the start of each codon in
                the CDS region. If None or all zero, every position is
                treated as UTR (single-character separated).

        Returns:
            Space-separated string ready for the tokenizer, e.g.
            "A T C G ATG TTT CCC" for "ATCGATGTTTCCC" with codon starts at
            positions 4, 7 and 10.
        """
        tokens = cls._parse_tokens(sequence)
        if cds is None:
            return " ".join(tokens)

        # Coerce so plain lists work: `some_list != 0` is a scalar True, which
        # would silently be read as a single codon start at position 0.
        starts = np.where(np.asarray(cds) != 0)[0]
        if starts.size == 0:
            return " ".join(tokens)

        start = int(starts[0])
        end = min(int(starts[-1]) + 3, len(tokens))

        cds_items = []
        for j in range(start, end, 3):
            # Clamp to `end`: if the marked starts are not in frame with the
            # first one, an unclamped slice would swallow positions that are
            # also emitted as trailing UTR below, duplicating nucleotides.
            codon = tokens[j:min(j + 3, end)]
            if len(codon) < 3 or any(t in cls._SPECIAL_TOKENS for t in codon):
                cds_items.extend(codon)
            else:
                cds_items.append("".join(codon))

        # One flat join avoids stray spaces when any of the parts is empty.
        return " ".join(tokens[:start] + cds_items + tokens[end:])

    @classmethod
    def chunk_sequence_cds_aware(cls, sequence, cds, chunk_length):
        """Chunk a sequence while respecting codon boundaries.

        A cut is moved left while a codon would start one or two positions
        before it (i.e. while the cut would split that codon). If no valid
        cut exists inside the window, the window is cut at full length so
        the loop always makes progress.

        Args:
            sequence: Raw nucleotide string.
            cds: Integer array (or list), one entry per nucleotide. Non-zero
                marks codon start positions.
            chunk_length: Maximum number of nucleotides per chunk; must be
                at least 1.

        Returns:
            List of (token_list_chunk, cds_chunk) tuples.

        Raises:
            ValueError: If chunk_length is smaller than 1.
        """
        if chunk_length < 1:
            # A non-positive length would never advance the cursor below.
            raise ValueError(
                f"chunk_length must be a positive integer, got {chunk_length!r}"
            )

        tokens = cls._parse_tokens(sequence)
        n = len(tokens)
        cds = np.asarray(cds)
        codon_starts = set(np.where(cds != 0)[0].tolist())

        if not codon_starts:
            # No CDS: plain fixed-width windows. max(n, 1) keeps one (empty)
            # chunk for an empty sequence.
            return [
                (tokens[i:i + chunk_length], cds[i:i + chunk_length])
                for i in range(0, max(n, 1), chunk_length)
            ]

        chunks = []
        i = 0
        while i < n:
            end = min(i + chunk_length, n)
            # Only a real cut can split a codon; when the window already
            # reaches the end of the sequence there is nothing to protect.
            if end < n:
                while end > i and (
                    (end - 1) in codon_starts or (end - 2) in codon_starts
                ):
                    end -= 1
                if end == i:
                    # No codon-aligned cut fits; fall back to a hard cut.
                    end = min(i + chunk_length, n)
            chunks.append((tokens[i:end], cds[i:end]))
            i = end
        return chunks

    def batch_encode_with_cds(self, sequences, cds, max_length=None, **kwargs):
        """Encode a batch of raw sequences using CDS-aware preprocessing.

        Sequences longer than the nucleotide budget are split into
        CDS-boundary-aligned chunks; each chunk is encoded as its own row and
        the caller is responsible for aggregating across chunks.

        Args:
            sequences: List of raw nucleotide strings (use T, not U).
            cds: List of integer numpy arrays, one per sequence.
                Non-zero values mark codon start positions.
            max_length: Per-chunk length limit. Two positions are reserved
                (presumably for the [CLS]/[SEP] special tokens), so each
                chunk holds at most ``max_length - 2`` nucleotides. Defaults
                to ``model_max_length`` when None (or 0).
            **kwargs: Forwarded to batch_encode_plus (e.g. return_tensors,
                padding, add_special_tokens).

        Returns:
            A tuple ``(encodings, chunk_counts)``. ``encodings`` is the
            result of batch_encode_plus over all chunks of all sequences,
            flattened in input order. ``chunk_counts[i]`` is the number of
            chunks produced for ``sequences[i]`` (all ones when nothing was
            split). Use chunk_counts to re-associate rows with sequences:

                chunk_ptr = 0
                for i, n_chunks in enumerate(chunk_counts):
                    seq_ids = encodings["input_ids"][chunk_ptr:chunk_ptr + n_chunks]
                    chunk_ptr += n_chunks

        Raises:
            ValueError: If sequences and cds differ in length, or if the
                length limit leaves no room for any nucleotide.
        """
        sequences = list(sequences)
        cds = list(cds)
        if len(sequences) != len(cds):
            # zip() would silently drop the unmatched tail.
            raise ValueError(
                f"sequences and cds must have the same length, "
                f"got {len(sequences)} and {len(cds)}"
            )

        limit = max_length or self.model_max_length
        budget = limit - 2
        if budget < 1:
            raise ValueError(
                f"max_length must be at least 3 to fit one nucleotide, got {limit!r}"
            )

        all_strings = []
        chunk_counts = []
        for seq, c in zip(sequences, cds):
            chunks = self.chunk_sequence_cds_aware(seq, c, budget)
            all_strings.extend(
                self.preprocess("".join(token_list), c_chunk)
                for token_list, c_chunk in chunks
            )
            chunk_counts.append(len(chunks))

        enc = self.batch_encode_plus(all_strings, **kwargs)
        return enc, chunk_counts