ChuxiJ committed on
Commit
033008e
·
1 Parent(s): 048f54f

fix max audio code id

Browse files
acestep/constrained_logits_processor.py CHANGED
@@ -19,6 +19,9 @@ from acestep.constants import (
19
  VALID_TIME_SIGNATURES,
20
  )
21
 
 
 
 
22
 
23
  # ==============================================================================
24
  # FSM States for Constrained Decoding
@@ -514,21 +517,61 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
514
  """
515
  Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
516
  These tokens should be blocked during caption generation.
 
517
  """
518
  import re
519
- audio_code_pattern = re.compile(r'^<\|audio_code_\d+\|>$')
 
520
 
521
  # Iterate through vocabulary to find audio code tokens
522
  for token_id in range(self.vocab_size):
523
  try:
524
  token_text = self.tokenizer.decode([token_id])
525
- if audio_code_pattern.match(token_text):
526
- self.audio_code_token_ids.add(token_id)
 
 
 
 
 
 
 
 
 
527
  except Exception:
528
  continue
529
 
530
- if self.debug:
531
- logger.debug(f"Found {len(self.audio_code_token_ids)} audio code tokens")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
 
533
  def _build_audio_code_mask(self):
534
  """
@@ -1497,6 +1540,8 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1497
 
1498
  if self.state == FSMState.CODES_GENERATION:
1499
  # Block all non-audio-code tokens (only allow audio codes and EOS)
 
 
1500
  if self.non_audio_code_mask is not None:
1501
  # Move mask to same device/dtype as scores if needed
1502
  if self.non_audio_code_mask.device != scores.device or self.non_audio_code_mask.dtype != scores.dtype:
 
19
  VALID_TIME_SIGNATURES,
20
  )
21
 
22
+ # Maximum valid audio code value (codebook size = 64000)
23
+ MAX_AUDIO_CODE = 63999
24
+
25
 
26
  # ==============================================================================
27
  # FSM States for Constrained Decoding
 
517
  """
518
  Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
519
  These tokens should be blocked during caption generation.
520
+ Only tokens with code values in range [0, MAX_AUDIO_CODE] are included.
521
  """
522
  import re
523
+ audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
524
+ invalid_tokens_count = 0
525
 
526
  # Iterate through vocabulary to find audio code tokens
527
  for token_id in range(self.vocab_size):
528
  try:
529
  token_text = self.tokenizer.decode([token_id])
530
+ match = audio_code_pattern.match(token_text)
531
+ if match:
532
+ # Extract code value from token text
533
+ code_value = int(match.group(1))
534
+ # Only add tokens with valid code values (0-63999)
535
+ if 0 <= code_value <= MAX_AUDIO_CODE:
536
+ self.audio_code_token_ids.add(token_id)
537
+ else:
538
+ invalid_tokens_count += 1
539
+ if self.debug:
540
+ logger.debug(f"Skipping audio code token {token_id} with invalid code value {code_value} (max: {MAX_AUDIO_CODE})")
541
  except Exception:
542
  continue
543
 
544
+ if invalid_tokens_count > 0:
545
+ logger.warning(f"Found {invalid_tokens_count} audio code tokens with values outside valid range [0, {MAX_AUDIO_CODE}]")
546
+
547
+ # Log warning if no valid tokens found (this would prevent code generation)
548
+ if len(self.audio_code_token_ids) == 0:
549
+ logger.warning(f"No valid audio code tokens found in vocabulary (range [0, {MAX_AUDIO_CODE}]). Code generation may fail.")
550
+ elif self.debug:
551
+ logger.debug(f"Found {len(self.audio_code_token_ids)} valid audio code tokens (range [0, {MAX_AUDIO_CODE}])")
552
+
553
+ def _extract_code_from_token(self, token_id: int) -> Optional[int]:
554
+ """
555
+ Extract audio code value from a token ID.
556
+
557
+ Args:
558
+ token_id: Token ID to extract code value from
559
+
560
+ Returns:
561
+ Code value if token is a valid audio code token, None otherwise
562
+ """
563
+ import re
564
+ audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
565
+
566
+ try:
567
+ token_text = self.tokenizer.decode([token_id])
568
+ match = audio_code_pattern.match(token_text)
569
+ if match:
570
+ return int(match.group(1))
571
+ except Exception:
572
+ pass
573
+
574
+ return None
575
 
576
  def _build_audio_code_mask(self):
577
  """
 
1540
 
1541
  if self.state == FSMState.CODES_GENERATION:
1542
  # Block all non-audio-code tokens (only allow audio codes and EOS)
1543
+ # Note: audio_code_token_ids already contains only valid tokens (0-63999 range)
1544
+ # because _precompute_audio_code_tokens() filters out invalid tokens during initialization
1545
  if self.non_audio_code_mask is not None:
1546
  # Move mask to same device/dtype as scores if needed
1547
  if self.non_audio_code_mask.device != scores.device or self.non_audio_code_mask.dtype != scores.dtype:
acestep/handler.py CHANGED
@@ -774,11 +774,29 @@ class AceStepHandler:
774
  return None
775
 
776
  def _parse_audio_code_string(self, code_str: str) -> List[int]:
777
- """Extract integer audio codes from prompt tokens like <|audio_code_123|>."""
 
 
778
  if not code_str:
779
  return []
780
  try:
781
- return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
782
  except Exception as e:
783
  logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
784
  return []
@@ -800,7 +818,8 @@ class AceStepHandler:
800
  detokenizer = self.model.detokenizer
801
 
802
  # Get codebook size for validation
803
- codebook_size = getattr(quantizer, 'codebook_size', 65536)
 
804
  if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
805
  codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
806
 
 
774
  return None
775
 
776
  def _parse_audio_code_string(self, code_str: str) -> List[int]:
777
+ """Extract integer audio codes from prompt tokens like <|audio_code_123|>.
778
+ Codes are clamped to valid range [0, 63999] (codebook size = 64000).
779
+ """
780
  if not code_str:
781
  return []
782
  try:
783
+ codes = [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
784
+ # Clamp codes to valid range [0, 63999]
785
+ MAX_AUDIO_CODE = 63999
786
+ clamped_codes = []
787
+ invalid_codes = []
788
+ for code in codes:
789
+ if code < 0 or code > MAX_AUDIO_CODE:
790
+ invalid_codes.append(code)
791
+ clamped_code = max(0, min(code, MAX_AUDIO_CODE))
792
+ clamped_codes.append(clamped_code)
793
+ else:
794
+ clamped_codes.append(code)
795
+
796
+ if invalid_codes:
797
+ logger.warning(f"[_parse_audio_code_string] Found {len(invalid_codes)} codes outside valid range [0, {MAX_AUDIO_CODE}]: {invalid_codes[:5]}... (clamped to valid range)")
798
+
799
+ return clamped_codes
800
  except Exception as e:
801
  logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
802
  return []
 
818
  detokenizer = self.model.detokenizer
819
 
820
  # Get codebook size for validation
821
+ # Default to 64000 (codebook size = 64000, valid range = 0-63999)
822
+ codebook_size = getattr(quantizer, 'codebook_size', 64000)
823
  if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
824
  codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
825