kashif HF Staff commited on
Commit
a5f56cd
·
verified ·
1 Parent(s): a955d9a

tokenizer: fix EOS append bug, decode skip_special_tokens=True, add auto_dna_tags

Browse files
Files changed (1) hide show
  1. tokenizer.py +50 -18
tokenizer.py CHANGED
@@ -8,7 +8,9 @@ Supports token_mask for Fine-grained Nucleotide Supervision (FNS):
8
  -2: padding token
9
  -1: text token (BPE)
10
  0: DNA special token (<dna>, </dna>, <oov>)
11
- 1-5: partial 6-mer (number of valid bases)
 
 
12
  6: full 6-mer
13
  """
14
 
@@ -26,6 +28,12 @@ class HybridDNATokenizer(PreTrainedTokenizer):
26
 
27
  DNA regions must be wrapped in <dna>...</dna> tags to be tokenized as 6-mers.
28
  Without tags, DNA sequences are tokenized as regular BPE text.
 
 
 
 
 
 
29
  """
30
 
31
  model_input_names = ["input_ids", "attention_mask"]
@@ -34,6 +42,7 @@ class HybridDNATokenizer(PreTrainedTokenizer):
34
  self,
35
  base_tokenizer_path: Optional[str] = None,
36
  k: int = 6,
 
37
  **kwargs
38
  ):
39
  self.k = k
@@ -63,6 +72,7 @@ class HybridDNATokenizer(PreTrainedTokenizer):
63
  )
64
 
65
  self.special_tokens = self.dna_special_tokens + [self._eos_token, self._pad_token]
 
66
 
67
  def _init_dna_vocab(self):
68
  """Initialize DNA vocabulary (special tokens + k-mers + padding for 128 alignment)."""
@@ -228,6 +238,10 @@ class HybridDNATokenizer(PreTrainedTokenizer):
228
 
229
  if remaining:
230
  padding_needed = k - len(remaining)
 
 
 
 
231
  padded = remaining + 'A' * padding_needed
232
 
233
  if is_valid_kmer(padded):
@@ -265,8 +279,13 @@ class HybridDNATokenizer(PreTrainedTokenizer):
265
  text: str,
266
  add_special_tokens: bool = False,
267
  return_token_mask: bool = False,
 
268
  **kwargs
269
  ) -> Union[List[int], Tuple[List[int], List[int]]]:
 
 
 
 
270
  segments = self._split_by_dna_tags(text)
271
 
272
  token_ids = []
@@ -309,10 +328,11 @@ class HybridDNATokenizer(PreTrainedTokenizer):
309
  if return_token_mask:
310
  token_mask.extend([-1] * len(base_ids))
311
 
312
- if add_special_tokens and self.eos_token_id is not None:
313
- token_ids.append(self.eos_token_id)
314
- if return_token_mask:
315
- token_mask.append(-1)
 
316
 
317
  if return_token_mask:
318
  return token_ids, token_mask
@@ -357,7 +377,14 @@ class HybridDNATokenizer(PreTrainedTokenizer):
357
  i += 1
358
 
359
  elif tid in self.dna_id_to_token:
360
- if not skip_special_tokens:
 
 
 
 
 
 
 
361
  parts.append(self.dna_id_to_token[tid])
362
  i += 1
363
 
@@ -400,6 +427,7 @@ class HybridDNATokenizer(PreTrainedTokenizer):
400
  max_length: Optional[int] = None,
401
  return_tensors: Optional[str] = None,
402
  return_token_mask: bool = False,
 
403
  **kwargs
404
  ) -> Dict[str, Any]:
405
  is_batch = isinstance(text, list)
@@ -410,11 +438,11 @@ class HybridDNATokenizer(PreTrainedTokenizer):
410
 
411
  for t in texts:
412
  if return_token_mask:
413
- ids, mask = self.encode(t, add_special_tokens=add_special_tokens, return_token_mask=True)
414
  all_ids.append(ids)
415
  all_masks.append(mask)
416
  else:
417
- ids = self.encode(t, add_special_tokens=add_special_tokens, return_token_mask=False)
418
  all_ids.append(ids)
419
 
420
  if padding:
@@ -496,6 +524,7 @@ class HybridDNATokenizer(PreTrainedTokenizer):
496
  "dna_start_id": self.dna_start_id,
497
  "dna_vocab_size": self.dna_vocab_size,
498
  "dna_special_tokens": self.dna_special_tokens,
 
499
  }
500
 
501
  dna_config_path = os.path.join(save_directory, "dna_config.json")
@@ -517,6 +546,7 @@ class HybridDNATokenizer(PreTrainedTokenizer):
517
  "AutoTokenizer": ["tokenizer.HybridDNATokenizer", None]
518
  },
519
  "k": self.k,
 
520
  })
521
 
522
  with open(config_path, "w", encoding="utf-8") as f:
@@ -533,19 +563,21 @@ class HybridDNATokenizer(PreTrainedTokenizer):
533
 
534
  @classmethod
535
  def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
 
 
 
536
  dna_config_path = os.path.join(pretrained_model_name_or_path, "dna_config.json")
 
537
 
538
  if os.path.exists(dna_config_path):
539
  with open(dna_config_path, "r") as f:
540
  dna_config = json.load(f)
541
  k = dna_config.get("k", 6)
542
- else:
543
- config_path = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
544
- if os.path.exists(config_path):
545
- with open(config_path, "r") as f:
546
- config = json.load(f)
547
- k = config.get("k", 6)
548
- else:
549
- k = 6
550
-
551
- return cls(base_tokenizer_path=pretrained_model_name_or_path, k=k, **kwargs)
 
8
  -2: padding token
9
  -1: text token (BPE)
10
  0: DNA special token (<dna>, </dna>, <oov>)
11
+ 1-5: partial 6-mer token valid_length real bases at positions [0, valid_length),
12
+ right-padded with 'A' at positions [valid_length, k) so loss can supervise
13
+ positions 0..valid_len-1 via pos_mask = (valid_len > pos)
14
  6: full 6-mer
15
  """
16
 
 
28
 
29
  DNA regions must be wrapped in <dna>...</dna> tags to be tokenized as 6-mers.
30
  Without tags, DNA sequences are tokenized as regular BPE text.
31
+
32
+ For pure-DNA input (no metadata tokens), pass auto_dna_tags=True to have
33
+ <dna>...</dna> tags added automatically when they are absent. Do NOT set
34
+ this if the input may contain BPE metadata such as species tags
35
+ (<fungi_species> etc.) — those must appear outside <dna>...</dna> and would
36
+ be incorrectly k-mer encoded if auto-wrapping fired.
37
  """
38
 
39
  model_input_names = ["input_ids", "attention_mask"]
 
42
  self,
43
  base_tokenizer_path: Optional[str] = None,
44
  k: int = 6,
45
+ auto_dna_tags: bool = False,
46
  **kwargs
47
  ):
48
  self.k = k
 
72
  )
73
 
74
  self.special_tokens = self.dna_special_tokens + [self._eos_token, self._pad_token]
75
+ self.auto_dna_tags = auto_dna_tags
76
 
77
  def _init_dna_vocab(self):
78
  """Initialize DNA vocabulary (special tokens + k-mers + padding for 128 alignment)."""
 
238
 
239
  if remaining:
240
  padding_needed = k - len(remaining)
241
+ # Right-pad with A: real bases occupy positions [0, valid_length).
242
+ # The hybrid BP loss supervises positions 0..valid_len-1 via
243
+ # pos_mask = (valid_len > pos)
244
+ # so padding must be at the END, not the start.
245
  padded = remaining + 'A' * padding_needed
246
 
247
  if is_valid_kmer(padded):
 
279
  text: str,
280
  add_special_tokens: bool = False,
281
  return_token_mask: bool = False,
282
+ auto_dna_tags: Optional[bool] = None,
283
  **kwargs
284
  ) -> Union[List[int], Tuple[List[int], List[int]]]:
285
+ use_auto = self.auto_dna_tags if auto_dna_tags is None else auto_dna_tags
286
+ if use_auto and '<dna>' not in text:
287
+ text = f'<dna>{text}</dna>'
288
+
289
  segments = self._split_by_dna_tags(text)
290
 
291
  token_ids = []
 
328
  if return_token_mask:
329
  token_mask.extend([-1] * len(base_ids))
330
 
331
+ # Do NOT append EOS when add_special_tokens=True. Qwen3 doesn't add
332
+ # BOS/EOS either, and appending EOS here breaks lighteval's
333
+ # tok_encode_pair: it relies on
334
+ # len(encode(ctx)) + len(encode(answer)) == len(encode(ctx + answer))
335
+ # which the extra EOS violates by shifting the split by 1.
336
 
337
  if return_token_mask:
338
  return token_ids, token_mask
 
377
  i += 1
378
 
379
  elif tid in self.dna_id_to_token:
380
+ # This branch handles k-mer tokens that appear without a <dna>
381
+ # wrapper — the common generation case where <dna> was in the
382
+ # prompt but only the generated portion is being decoded.
383
+ # K-mer tokens are content, not special tokens, so always decode
384
+ # them. Only drop true DNA special tokens (<dna>, </dna>, <oov>)
385
+ # when skip_special_tokens=True.
386
+ is_dna_special = tid in (self.dna_begin_token_id, self.dna_end_token_id, self.oov_token_id)
387
+ if not (skip_special_tokens and is_dna_special):
388
  parts.append(self.dna_id_to_token[tid])
389
  i += 1
390
 
 
427
  max_length: Optional[int] = None,
428
  return_tensors: Optional[str] = None,
429
  return_token_mask: bool = False,
430
+ auto_dna_tags: Optional[bool] = None,
431
  **kwargs
432
  ) -> Dict[str, Any]:
433
  is_batch = isinstance(text, list)
 
438
 
439
  for t in texts:
440
  if return_token_mask:
441
+ ids, mask = self.encode(t, add_special_tokens=add_special_tokens, return_token_mask=True, auto_dna_tags=auto_dna_tags)
442
  all_ids.append(ids)
443
  all_masks.append(mask)
444
  else:
445
+ ids = self.encode(t, add_special_tokens=add_special_tokens, return_token_mask=False, auto_dna_tags=auto_dna_tags)
446
  all_ids.append(ids)
447
 
448
  if padding:
 
524
  "dna_start_id": self.dna_start_id,
525
  "dna_vocab_size": self.dna_vocab_size,
526
  "dna_special_tokens": self.dna_special_tokens,
527
+ "auto_dna_tags": self.auto_dna_tags,
528
  }
529
 
530
  dna_config_path = os.path.join(save_directory, "dna_config.json")
 
546
  "AutoTokenizer": ["tokenizer.HybridDNATokenizer", None]
547
  },
548
  "k": self.k,
549
+ "auto_dna_tags": self.auto_dna_tags,
550
  })
551
 
552
  with open(config_path, "w", encoding="utf-8") as f:
 
563
 
564
  @classmethod
565
  def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
566
+ k = 6
567
+ auto_dna_tags = False
568
+
569
  dna_config_path = os.path.join(pretrained_model_name_or_path, "dna_config.json")
570
+ tok_config_path = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
571
 
572
  if os.path.exists(dna_config_path):
573
  with open(dna_config_path, "r") as f:
574
  dna_config = json.load(f)
575
  k = dna_config.get("k", 6)
576
+ auto_dna_tags = dna_config.get("auto_dna_tags", False)
577
+ elif os.path.exists(tok_config_path):
578
+ with open(tok_config_path, "r") as f:
579
+ tok_config = json.load(f)
580
+ k = tok_config.get("k", 6)
581
+ auto_dna_tags = tok_config.get("auto_dna_tags", False)
582
+
583
+ return cls(base_tokenizer_path=pretrained_model_name_or_path, k=k, auto_dna_tags=auto_dna_tags, **kwargs)