Taykhoom committed on
Commit
e1ac8f5
·
verified ·
1 Parent(s): ef53d17

Upload tokenization_rnafm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tokenization_rnafm.py +58 -0
tokenization_rnafm.py CHANGED
@@ -109,3 +109,61 @@ class RnaFmTokenizer(PreTrainedTokenizer):
109
  if token_ids_1 is None:
110
  return [0] * (len(token_ids_0) + 2)
111
  return [0] * (len(token_ids_0) + 2) + [0] * (len(token_ids_1) + 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  if token_ids_1 is None:
110
  return [0] * (len(token_ids_0) + 2)
111
  return [0] * (len(token_ids_0) + 2) + [0] * (len(token_ids_1) + 2)
112
+
113
@staticmethod
def _extract_cds(sequence, cds):
    """Extract the CDS region from a sequence, trimmed to a multiple of 3.

    Args:
        sequence: Nucleotide string to slice.
        cds: Array-like of per-position flags in which ``1`` marks a CDS
            codon start. May be a NumPy array or a plain Python sequence.

    Returns:
        The substring that starts at the first flagged position and runs
        through the two nucleotides following the last flagged position,
        truncated at the end so its length is divisible by 3. When the flags
        sum to zero, the whole sequence truncated to a multiple of 3.
    """
    import numpy as np

    # Coerce once: with a plain list, ``cds == 1`` is the scalar ``False`` and
    # ``np.argmax`` would then always report position 0 as the CDS start.
    flags = np.asarray(cds)

    if flags.sum() == 0:
        # No CDS annotation: fall back to the full sequence, codon-aligned.
        return sequence[:len(sequence) - (len(sequence) % 3)]

    is_start = flags == 1
    first = int(np.argmax(is_start))
    # argmax over the reversed mask gives the distance of the last flagged
    # position from the end of the array.
    last_start = len(flags) - 1 - int(np.argmax(is_start[::-1]))
    # The last flag marks a codon *start*, so include its two trailing bases.
    region = sequence[first:last_start + 3]

    remainder = len(region) % 3
    if remainder:
        region = region[:-remainder]
    return region
125
+
126
def batch_encode_with_cds(self, sequences, cds, max_length=None, **kwargs):
    """Encode sequences with CDS extraction (k_mer=3 / mRNA-FM only).

    Applies T->U, extracts the CDS region, chunks to max_length nucleotides
    (aligned to codon boundaries), and encodes each chunk.

    Args:
        sequences: List of raw nucleotide strings (T or U).
        cds: List of numpy arrays marking CDS codon start positions, paired
            positionally with ``sequences``.
        max_length: Nucleotide budget per chunk (defaults to
            (model_max_length - 2) * k_mer). Rounded down to a multiple of
            k_mer; must leave room for at least one codon.
        **kwargs: Forwarded to batch_encode_plus (e.g. return_tensors,
            padding, add_special_tokens).

    Returns:
        Tuple of (BatchEncoding, chunk_counts) where chunk_counts[i] is the
        number of chunks produced for sequences[i]. A sequence that yields
        no encodable chunk contributes the single placeholder chunk "AUG",
        so every count is at least 1.

    Raises:
        ValueError: If the tokenizer's k_mer is not 3, or if the nucleotide
            budget is too small to hold one codon.
    """
    if self.k_mer != 3:
        raise ValueError("batch_encode_with_cds requires k_mer=3 (mRNA-FM tokenizer)")

    k_mer = self.k_mer
    budget = max_length if max_length is not None else (self.model_max_length - 2) * k_mer
    # Round down so every chunk boundary falls on a codon boundary.
    budget = (budget // k_mer) * k_mer
    if budget <= 0:
        # A zero step would otherwise fail inside range() with an unrelated
        # message; a negative one would silently emit only placeholders.
        raise ValueError(
            f"max_length must allow at least one codon ({k_mer} nucleotides) "
            f"per chunk; got a usable budget of {budget}"
        )

    all_chunks = []
    chunk_counts = []

    for seq, c in zip(sequences, cds):
        seq = seq.replace("T", "U").replace("t", "u")
        seq = self._extract_cds(seq, c)

        chunks = []
        # max(len, 1) keeps the loop running once for an empty CDS so the
        # placeholder fallback below is reached uniformly.
        for start in range(0, max(len(seq), 1), budget):
            chunk = seq[start:start + budget]
            leftover = len(chunk) % k_mer
            if leftover:
                chunk = chunk[:-leftover]
            if chunk:
                chunks.append(chunk)

        if not chunks:
            # Keep one entry per input so chunk_counts stays aligned.
            chunks = ["AUG"]

        all_chunks.extend(chunks)
        chunk_counts.append(len(chunks))

    enc = self.batch_encode_plus(all_chunks, **kwargs)
    return enc, chunk_counts