ChuxiJ committed on
Commit
f2036a7
·
1 Parent(s): fe01169

fix max audio code id

Browse files
acestep/constrained_logits_processor.py CHANGED
@@ -20,6 +20,12 @@ from acestep.constants import (
20
  )
21
 
22
 
 
 
 
 
 
 
23
  # ==============================================================================
24
  # FSM States for Constrained Decoding
25
  # ==============================================================================
@@ -514,21 +520,34 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
514
  """
515
  Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
516
  These tokens should be blocked during caption generation.
 
517
  """
518
  import re
519
- audio_code_pattern = re.compile(r'^<\|audio_code_\d+\|>$')
 
520
 
521
  # Iterate through vocabulary to find audio code tokens
522
  for token_id in range(self.vocab_size):
523
  try:
524
  token_text = self.tokenizer.decode([token_id])
525
- if audio_code_pattern.match(token_text):
526
- self.audio_code_token_ids.add(token_id)
 
 
 
 
 
 
 
 
 
527
  except Exception:
528
  continue
529
 
530
  if self.debug:
531
- logger.debug(f"Found {len(self.audio_code_token_ids)} audio code tokens")
 
 
532
 
533
  def _build_audio_code_mask(self):
534
  """
@@ -1503,6 +1522,24 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1503
  self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
1504
  scores = scores + self.non_audio_code_mask
1505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1506
  # Apply duration constraint in codes generation phase
1507
  if self.target_codes is not None and self.eos_token_id is not None:
1508
  if self.codes_count < self.target_codes:
 
20
  )
21
 
22
 
23
+ # ==============================================================================
24
+ # Constants
25
+ # ==============================================================================
26
+ # Maximum valid audio code value (codebook size = 64000, valid range: 0-63999)
27
+ MAX_AUDIO_CODE = 63999
28
+
29
  # ==============================================================================
30
  # FSM States for Constrained Decoding
31
  # ==============================================================================
 
520
  """
521
  Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
522
  These tokens should be blocked during caption generation.
523
+ Only tokens with code values in range [0, MAX_AUDIO_CODE] are included.
524
  """
525
  import re
526
+ audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
527
+ out_of_range_count = 0
528
 
529
  # Iterate through vocabulary to find audio code tokens
530
  for token_id in range(self.vocab_size):
531
  try:
532
  token_text = self.tokenizer.decode([token_id])
533
+ match = audio_code_pattern.match(token_text)
534
+ if match:
535
+ # Extract code value from token text
536
+ code_value = int(match.group(1))
537
+ # Only add tokens with valid code values (0-63999)
538
+ if 0 <= code_value <= MAX_AUDIO_CODE:
539
+ self.audio_code_token_ids.add(token_id)
540
+ else:
541
+ out_of_range_count += 1
542
+ if self.debug:
543
+ logger.debug(f"Skipping audio code token with out-of-range value: {token_text} (code={code_value})")
544
  except Exception:
545
  continue
546
 
547
  if self.debug:
548
+ logger.debug(f"Found {len(self.audio_code_token_ids)} valid audio code tokens (skipped {out_of_range_count} out-of-range tokens)")
549
+ if out_of_range_count > 0:
550
+ logger.warning(f"Skipped {out_of_range_count} audio code tokens with values outside valid range [0, {MAX_AUDIO_CODE}]")
551
 
552
  def _build_audio_code_mask(self):
553
  """
 
1522
  self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
1523
  scores = scores + self.non_audio_code_mask
1524
 
1525
+ # Additional validation: block audio code tokens with out-of-range values
1526
+ # This prevents generation of codes > MAX_AUDIO_CODE even if they exist in vocabulary
1527
+ import re
1528
+ audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
1529
+ for token_id in self.audio_code_token_ids:
1530
+ try:
1531
+ token_text = self.tokenizer.decode([token_id])
1532
+ match = audio_code_pattern.match(token_text)
1533
+ if match:
1534
+ code_value = int(match.group(1))
1535
+ # Block tokens with code values outside valid range
1536
+ if code_value > MAX_AUDIO_CODE:
1537
+ scores[:, token_id] = float('-inf')
1538
+ if self.debug:
1539
+ logger.debug(f"Blocking out-of-range audio code token: {token_text} (code={code_value})")
1540
+ except Exception:
1541
+ continue
1542
+
1543
  # Apply duration constraint in codes generation phase
1544
  if self.target_codes is not None and self.eos_token_id is not None:
1545
  if self.codes_count < self.target_codes:
acestep/handler.py CHANGED
@@ -774,11 +774,32 @@ class AceStepHandler:
774
  return None
775
 
776
  def _parse_audio_code_string(self, code_str: str) -> List[int]:
777
- """Extract integer audio codes from prompt tokens like <|audio_code_123|>."""
 
 
 
778
  if not code_str:
779
  return []
780
  try:
781
- return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
782
  except Exception as e:
783
  logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
784
  return []
@@ -800,16 +821,23 @@ class AceStepHandler:
800
  detokenizer = self.model.detokenizer
801
 
802
  # Get codebook size for validation
803
- codebook_size = getattr(quantizer, 'codebook_size', 65536)
 
 
804
  if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
805
  codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
806
 
807
- # Validate code IDs are within valid range
808
- invalid_codes = [c for c in code_ids if c < 0 or c >= codebook_size]
 
 
 
 
 
809
  if invalid_codes:
810
- logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {codebook_size}): {invalid_codes[:5]}...")
811
- # Clamp invalid codes to valid range
812
- code_ids = [max(0, min(c, codebook_size - 1)) for c in code_ids]
813
 
814
  num_quantizers = getattr(quantizer, "num_quantizers", 1)
815
  # Create indices tensor: [T_5Hz]
 
774
  return None
775
 
776
  def _parse_audio_code_string(self, code_str: str) -> List[int]:
777
+ """
778
+ Extract integer audio codes from prompt tokens like <|audio_code_123|>.
779
+ Clamps code values to valid range [0, 63999].
780
+ """
781
  if not code_str:
782
  return []
783
  try:
784
+ MAX_AUDIO_CODE = 63999
785
+ codes = [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
786
+ # Clamp codes to valid range [0, 63999]
787
+ clamped_codes = []
788
+ clamped_count = 0
789
+ for code in codes:
790
+ if code < 0:
791
+ clamped_codes.append(0)
792
+ clamped_count += 1
793
+ elif code > MAX_AUDIO_CODE:
794
+ clamped_codes.append(MAX_AUDIO_CODE)
795
+ clamped_count += 1
796
+ else:
797
+ clamped_codes.append(code)
798
+
799
+ if clamped_count > 0:
800
+ logger.warning(f"[_parse_audio_code_string] Clamped {clamped_count} audio code values to valid range [0, {MAX_AUDIO_CODE}]")
801
+
802
+ return clamped_codes
803
  except Exception as e:
804
  logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
805
  return []
 
821
  detokenizer = self.model.detokenizer
822
 
823
  # Get codebook size for validation
824
+ # DIT quantizer supports codebook size = 64000 (valid range: 0-63999)
825
+ MAX_AUDIO_CODE = 63999
826
+ codebook_size = getattr(quantizer, 'codebook_size', 64000)
827
  if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
828
  codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
829
 
830
+ # Use 64000 as hard limit regardless of what quantizer reports
831
+ # This ensures compatibility with the actual DIT quantizer codebook size
832
+ effective_codebook_size = 64000
833
+ effective_max_code = MAX_AUDIO_CODE
834
+
835
+ # Validate code IDs are within valid range [0, 63999]
836
+ invalid_codes = [c for c in code_ids if c < 0 or c > effective_max_code]
837
  if invalid_codes:
838
+ logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {effective_max_code}]: {invalid_codes[:5]}...")
839
+ # Clamp invalid codes to valid range [0, 63999]
840
+ code_ids = [max(0, min(c, effective_max_code)) for c in code_ids]
841
 
842
  num_quantizers = getattr(quantizer, "num_quantizers", 1)
843
  # Create indices tensor: [T_5Hz]