ChuxiJ committed on
Commit
048f54f
·
1 Parent(s): f2036a7

reverse fix

Browse files
acestep/constrained_logits_processor.py CHANGED
@@ -20,12 +20,6 @@ from acestep.constants import (
20
  )
21
 
22
 
23
- # ==============================================================================
24
- # Constants
25
- # ==============================================================================
26
- # Maximum valid audio code value (codebook size = 64000, valid range: 0-63999)
27
- MAX_AUDIO_CODE = 63999
28
-
29
  # ==============================================================================
30
  # FSM States for Constrained Decoding
31
  # ==============================================================================
@@ -520,34 +514,21 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
520
  """
521
  Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
522
  These tokens should be blocked during caption generation.
523
- Only tokens with code values in range [0, MAX_AUDIO_CODE] are included.
524
  """
525
  import re
526
- audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
527
- out_of_range_count = 0
528
 
529
  # Iterate through vocabulary to find audio code tokens
530
  for token_id in range(self.vocab_size):
531
  try:
532
  token_text = self.tokenizer.decode([token_id])
533
- match = audio_code_pattern.match(token_text)
534
- if match:
535
- # Extract code value from token text
536
- code_value = int(match.group(1))
537
- # Only add tokens with valid code values (0-63999)
538
- if 0 <= code_value <= MAX_AUDIO_CODE:
539
- self.audio_code_token_ids.add(token_id)
540
- else:
541
- out_of_range_count += 1
542
- if self.debug:
543
- logger.debug(f"Skipping audio code token with out-of-range value: {token_text} (code={code_value})")
544
  except Exception:
545
  continue
546
 
547
  if self.debug:
548
- logger.debug(f"Found {len(self.audio_code_token_ids)} valid audio code tokens (skipped {out_of_range_count} out-of-range tokens)")
549
- if out_of_range_count > 0:
550
- logger.warning(f"Skipped {out_of_range_count} audio code tokens with values outside valid range [0, {MAX_AUDIO_CODE}]")
551
 
552
  def _build_audio_code_mask(self):
553
  """
@@ -1522,24 +1503,6 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
1522
  self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
1523
  scores = scores + self.non_audio_code_mask
1524
 
1525
- # Additional validation: block audio code tokens with out-of-range values
1526
- # This prevents generation of codes > MAX_AUDIO_CODE even if they exist in vocabulary
1527
- import re
1528
- audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
1529
- for token_id in self.audio_code_token_ids:
1530
- try:
1531
- token_text = self.tokenizer.decode([token_id])
1532
- match = audio_code_pattern.match(token_text)
1533
- if match:
1534
- code_value = int(match.group(1))
1535
- # Block tokens with code values outside valid range
1536
- if code_value > MAX_AUDIO_CODE:
1537
- scores[:, token_id] = float('-inf')
1538
- if self.debug:
1539
- logger.debug(f"Blocking out-of-range audio code token: {token_text} (code={code_value})")
1540
- except Exception:
1541
- continue
1542
-
1543
  # Apply duration constraint in codes generation phase
1544
  if self.target_codes is not None and self.eos_token_id is not None:
1545
  if self.codes_count < self.target_codes:
 
20
  )
21
 
22
 
 
 
 
 
 
 
23
  # ==============================================================================
24
  # FSM States for Constrained Decoding
25
  # ==============================================================================
 
514
  """
515
  Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
516
  These tokens should be blocked during caption generation.
 
517
  """
518
  import re
519
+ audio_code_pattern = re.compile(r'^<\|audio_code_\d+\|>$')
 
520
 
521
  # Iterate through vocabulary to find audio code tokens
522
  for token_id in range(self.vocab_size):
523
  try:
524
  token_text = self.tokenizer.decode([token_id])
525
+ if audio_code_pattern.match(token_text):
526
+ self.audio_code_token_ids.add(token_id)
 
 
 
 
 
 
 
 
 
527
  except Exception:
528
  continue
529
 
530
  if self.debug:
531
+ logger.debug(f"Found {len(self.audio_code_token_ids)} audio code tokens")
 
 
532
 
533
  def _build_audio_code_mask(self):
534
  """
 
1503
  self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
1504
  scores = scores + self.non_audio_code_mask
1505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1506
  # Apply duration constraint in codes generation phase
1507
  if self.target_codes is not None and self.eos_token_id is not None:
1508
  if self.codes_count < self.target_codes:
acestep/handler.py CHANGED
@@ -774,32 +774,11 @@ class AceStepHandler:
774
  return None
775
 
776
  def _parse_audio_code_string(self, code_str: str) -> List[int]:
777
- """
778
- Extract integer audio codes from prompt tokens like <|audio_code_123|>.
779
- Clamps code values to valid range [0, 63999].
780
- """
781
  if not code_str:
782
  return []
783
  try:
784
- MAX_AUDIO_CODE = 63999
785
- codes = [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
786
- # Clamp codes to valid range [0, 63999]
787
- clamped_codes = []
788
- clamped_count = 0
789
- for code in codes:
790
- if code < 0:
791
- clamped_codes.append(0)
792
- clamped_count += 1
793
- elif code > MAX_AUDIO_CODE:
794
- clamped_codes.append(MAX_AUDIO_CODE)
795
- clamped_count += 1
796
- else:
797
- clamped_codes.append(code)
798
-
799
- if clamped_count > 0:
800
- logger.warning(f"[_parse_audio_code_string] Clamped {clamped_count} audio code values to valid range [0, {MAX_AUDIO_CODE}]")
801
-
802
- return clamped_codes
803
  except Exception as e:
804
  logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
805
  return []
@@ -821,23 +800,16 @@ class AceStepHandler:
821
  detokenizer = self.model.detokenizer
822
 
823
  # Get codebook size for validation
824
- # DIT quantizer supports codebook size = 64000 (valid range: 0-63999)
825
- MAX_AUDIO_CODE = 63999
826
- codebook_size = getattr(quantizer, 'codebook_size', 64000)
827
  if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
828
  codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
829
 
830
- # Use 64000 as hard limit regardless of what quantizer reports
831
- # This ensures compatibility with the actual DIT quantizer codebook size
832
- effective_codebook_size = 64000
833
- effective_max_code = MAX_AUDIO_CODE
834
-
835
- # Validate code IDs are within valid range [0, 63999]
836
- invalid_codes = [c for c in code_ids if c < 0 or c > effective_max_code]
837
  if invalid_codes:
838
- logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {effective_max_code}]: {invalid_codes[:5]}...")
839
- # Clamp invalid codes to valid range [0, 63999]
840
- code_ids = [max(0, min(c, effective_max_code)) for c in code_ids]
841
 
842
  num_quantizers = getattr(quantizer, "num_quantizers", 1)
843
  # Create indices tensor: [T_5Hz]
 
774
  return None
775
 
776
  def _parse_audio_code_string(self, code_str: str) -> List[int]:
777
+ """Extract integer audio codes from prompt tokens like <|audio_code_123|>."""
 
 
 
778
  if not code_str:
779
  return []
780
  try:
781
+ return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
782
  except Exception as e:
783
  logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
784
  return []
 
800
  detokenizer = self.model.detokenizer
801
 
802
  # Get codebook size for validation
803
+ codebook_size = getattr(quantizer, 'codebook_size', 65536)
 
 
804
  if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
805
  codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
806
 
807
+ # Validate code IDs are within valid range
808
+ invalid_codes = [c for c in code_ids if c < 0 or c >= codebook_size]
 
 
 
 
 
809
  if invalid_codes:
810
+ logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {codebook_size}): {invalid_codes[:5]}...")
811
+ # Clamp invalid codes to valid range
812
+ code_ids = [max(0, min(c, codebook_size - 1)) for c in code_ids]
813
 
814
  num_quantizers = getattr(quantizer, "num_quantizers", 1)
815
  # Create indices tensor: [T_5Hz]