Spaces:
Sleeping
Sleeping
fix max audio code id
Browse files
- acestep/constrained_logits_processor.py +41 -4
- acestep/handler.py +36 -8
acestep/constrained_logits_processor.py
CHANGED
|
@@ -20,6 +20,12 @@ from acestep.constants import (
|
|
| 20 |
)
|
| 21 |
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# ==============================================================================
|
| 24 |
# FSM States for Constrained Decoding
|
| 25 |
# ==============================================================================
|
|
@@ -514,21 +520,34 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 514 |
"""
|
| 515 |
Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
|
| 516 |
These tokens should be blocked during caption generation.
|
|
|
|
| 517 |
"""
|
| 518 |
import re
|
| 519 |
-
audio_code_pattern = re.compile(r'^<\|audio_code_\d+\|>$')
|
|
|
|
| 520 |
|
| 521 |
# Iterate through vocabulary to find audio code tokens
|
| 522 |
for token_id in range(self.vocab_size):
|
| 523 |
try:
|
| 524 |
token_text = self.tokenizer.decode([token_id])
|
| 525 |
-
|
| 526 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
except Exception:
|
| 528 |
continue
|
| 529 |
|
| 530 |
if self.debug:
|
| 531 |
-
logger.debug(f"Found {len(self.audio_code_token_ids)} audio code tokens")
|
|
|
|
|
|
|
| 532 |
|
| 533 |
def _build_audio_code_mask(self):
|
| 534 |
"""
|
|
@@ -1503,6 +1522,24 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1503 |
self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
|
| 1504 |
scores = scores + self.non_audio_code_mask
|
| 1505 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1506 |
# Apply duration constraint in codes generation phase
|
| 1507 |
if self.target_codes is not None and self.eos_token_id is not None:
|
| 1508 |
if self.codes_count < self.target_codes:
|
|
|
|
| 20 |
)
|
| 21 |
|
| 22 |
|
| 23 |
+
# ==============================================================================
|
| 24 |
+
# Constants
|
| 25 |
+
# ==============================================================================
|
| 26 |
+
# Maximum valid audio code value (codebook size = 64000, valid range: 0-63999)
|
| 27 |
+
MAX_AUDIO_CODE = 63999
|
| 28 |
+
|
| 29 |
# ==============================================================================
|
| 30 |
# FSM States for Constrained Decoding
|
| 31 |
# ==============================================================================
|
|
|
|
| 520 |
"""
|
| 521 |
Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
|
| 522 |
These tokens should be blocked during caption generation.
|
| 523 |
+
Only tokens with code values in range [0, MAX_AUDIO_CODE] are included.
|
| 524 |
"""
|
| 525 |
import re
|
| 526 |
+
audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
|
| 527 |
+
out_of_range_count = 0
|
| 528 |
|
| 529 |
# Iterate through vocabulary to find audio code tokens
|
| 530 |
for token_id in range(self.vocab_size):
|
| 531 |
try:
|
| 532 |
token_text = self.tokenizer.decode([token_id])
|
| 533 |
+
match = audio_code_pattern.match(token_text)
|
| 534 |
+
if match:
|
| 535 |
+
# Extract code value from token text
|
| 536 |
+
code_value = int(match.group(1))
|
| 537 |
+
# Only add tokens with valid code values (0-63999)
|
| 538 |
+
if 0 <= code_value <= MAX_AUDIO_CODE:
|
| 539 |
+
self.audio_code_token_ids.add(token_id)
|
| 540 |
+
else:
|
| 541 |
+
out_of_range_count += 1
|
| 542 |
+
if self.debug:
|
| 543 |
+
logger.debug(f"Skipping audio code token with out-of-range value: {token_text} (code={code_value})")
|
| 544 |
except Exception:
|
| 545 |
continue
|
| 546 |
|
| 547 |
if self.debug:
|
| 548 |
+
logger.debug(f"Found {len(self.audio_code_token_ids)} valid audio code tokens (skipped {out_of_range_count} out-of-range tokens)")
|
| 549 |
+
if out_of_range_count > 0:
|
| 550 |
+
logger.warning(f"Skipped {out_of_range_count} audio code tokens with values outside valid range [0, {MAX_AUDIO_CODE}]")
|
| 551 |
|
| 552 |
def _build_audio_code_mask(self):
|
| 553 |
"""
|
|
|
|
| 1522 |
self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
|
| 1523 |
scores = scores + self.non_audio_code_mask
|
| 1524 |
|
| 1525 |
+
# Additional validation: block audio code tokens with out-of-range values
|
| 1526 |
+
# This prevents generation of codes > MAX_AUDIO_CODE even if they exist in vocabulary
|
| 1527 |
+
import re
|
| 1528 |
+
audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
|
| 1529 |
+
for token_id in self.audio_code_token_ids:
|
| 1530 |
+
try:
|
| 1531 |
+
token_text = self.tokenizer.decode([token_id])
|
| 1532 |
+
match = audio_code_pattern.match(token_text)
|
| 1533 |
+
if match:
|
| 1534 |
+
code_value = int(match.group(1))
|
| 1535 |
+
# Block tokens with code values outside valid range
|
| 1536 |
+
if code_value > MAX_AUDIO_CODE:
|
| 1537 |
+
scores[:, token_id] = float('-inf')
|
| 1538 |
+
if self.debug:
|
| 1539 |
+
logger.debug(f"Blocking out-of-range audio code token: {token_text} (code={code_value})")
|
| 1540 |
+
except Exception:
|
| 1541 |
+
continue
|
| 1542 |
+
|
| 1543 |
# Apply duration constraint in codes generation phase
|
| 1544 |
if self.target_codes is not None and self.eos_token_id is not None:
|
| 1545 |
if self.codes_count < self.target_codes:
|
acestep/handler.py
CHANGED
|
@@ -774,11 +774,32 @@ class AceStepHandler:
|
|
| 774 |
return None
|
| 775 |
|
| 776 |
def _parse_audio_code_string(self, code_str: str) -> List[int]:
|
| 777 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 778 |
if not code_str:
|
| 779 |
return []
|
| 780 |
try:
|
| 781 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 782 |
except Exception as e:
|
| 783 |
logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
|
| 784 |
return []
|
|
@@ -800,16 +821,23 @@ class AceStepHandler:
|
|
| 800 |
detokenizer = self.model.detokenizer
|
| 801 |
|
| 802 |
# Get codebook size for validation
|
| 803 |
-
|
|
|
|
|
|
|
| 804 |
if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
|
| 805 |
codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
|
| 806 |
|
| 807 |
-
#
|
| 808 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 809 |
if invalid_codes:
|
| 810 |
-
logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {
|
| 811 |
-
# Clamp invalid codes to valid range
|
| 812 |
-
code_ids = [max(0, min(c,
|
| 813 |
|
| 814 |
num_quantizers = getattr(quantizer, "num_quantizers", 1)
|
| 815 |
# Create indices tensor: [T_5Hz]
|
|
|
|
| 774 |
return None
|
| 775 |
|
| 776 |
def _parse_audio_code_string(self, code_str: str) -> List[int]:
|
| 777 |
+
"""
|
| 778 |
+
Extract integer audio codes from prompt tokens like <|audio_code_123|>.
|
| 779 |
+
Clamps code values to valid range [0, 63999].
|
| 780 |
+
"""
|
| 781 |
if not code_str:
|
| 782 |
return []
|
| 783 |
try:
|
| 784 |
+
MAX_AUDIO_CODE = 63999
|
| 785 |
+
codes = [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
|
| 786 |
+
# Clamp codes to valid range [0, 63999]
|
| 787 |
+
clamped_codes = []
|
| 788 |
+
clamped_count = 0
|
| 789 |
+
for code in codes:
|
| 790 |
+
if code < 0:
|
| 791 |
+
clamped_codes.append(0)
|
| 792 |
+
clamped_count += 1
|
| 793 |
+
elif code > MAX_AUDIO_CODE:
|
| 794 |
+
clamped_codes.append(MAX_AUDIO_CODE)
|
| 795 |
+
clamped_count += 1
|
| 796 |
+
else:
|
| 797 |
+
clamped_codes.append(code)
|
| 798 |
+
|
| 799 |
+
if clamped_count > 0:
|
| 800 |
+
logger.warning(f"[_parse_audio_code_string] Clamped {clamped_count} audio code values to valid range [0, {MAX_AUDIO_CODE}]")
|
| 801 |
+
|
| 802 |
+
return clamped_codes
|
| 803 |
except Exception as e:
|
| 804 |
logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
|
| 805 |
return []
|
|
|
|
| 821 |
detokenizer = self.model.detokenizer
|
| 822 |
|
| 823 |
# Get codebook size for validation
|
| 824 |
+
# DIT quantizer supports codebook size = 64000 (valid range: 0-63999)
|
| 825 |
+
MAX_AUDIO_CODE = 63999
|
| 826 |
+
codebook_size = getattr(quantizer, 'codebook_size', 64000)
|
| 827 |
if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
|
| 828 |
codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
|
| 829 |
|
| 830 |
+
# Use 64000 as hard limit regardless of what quantizer reports
|
| 831 |
+
# This ensures compatibility with the actual DIT quantizer codebook size
|
| 832 |
+
effective_codebook_size = 64000
|
| 833 |
+
effective_max_code = MAX_AUDIO_CODE
|
| 834 |
+
|
| 835 |
+
# Validate code IDs are within valid range [0, 63999]
|
| 836 |
+
invalid_codes = [c for c in code_ids if c < 0 or c > effective_max_code]
|
| 837 |
if invalid_codes:
|
| 838 |
+
logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {effective_max_code}]: {invalid_codes[:5]}...")
|
| 839 |
+
# Clamp invalid codes to valid range [0, 63999]
|
| 840 |
+
code_ids = [max(0, min(c, effective_max_code)) for c in code_ids]
|
| 841 |
|
| 842 |
num_quantizers = getattr(quantizer, "num_quantizers", 1)
|
| 843 |
# Create indices tensor: [T_5Hz]
|