Spaces:
Running
on
A100
Running
on
A100
fix max audio code id
Browse files- acestep/constrained_logits_processor.py +50 -5
- acestep/handler.py +22 -3
acestep/constrained_logits_processor.py
CHANGED
|
@@ -19,6 +19,9 @@ from acestep.constants import (
|
|
| 19 |
VALID_TIME_SIGNATURES,
|
| 20 |
)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# ==============================================================================
|
| 24 |
# FSM States for Constrained Decoding
|
|
@@ -514,21 +517,61 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 514 |
"""
|
| 515 |
Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
|
| 516 |
These tokens should be blocked during caption generation.
|
|
|
|
| 517 |
"""
|
| 518 |
import re
|
| 519 |
-
audio_code_pattern = re.compile(r'^<\|audio_code_\d
|
|
|
|
| 520 |
|
| 521 |
# Iterate through vocabulary to find audio code tokens
|
| 522 |
for token_id in range(self.vocab_size):
|
| 523 |
try:
|
| 524 |
token_text = self.tokenizer.decode([token_id])
|
| 525 |
-
|
| 526 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
except Exception:
|
| 528 |
continue
|
| 529 |
|
| 530 |
-
if
|
| 531 |
-
logger.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
|
| 533 |
def _build_audio_code_mask(self):
|
| 534 |
"""
|
|
@@ -1497,6 +1540,8 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1497 |
|
| 1498 |
if self.state == FSMState.CODES_GENERATION:
|
| 1499 |
# Block all non-audio-code tokens (only allow audio codes and EOS)
|
|
|
|
|
|
|
| 1500 |
if self.non_audio_code_mask is not None:
|
| 1501 |
# Move mask to same device/dtype as scores if needed
|
| 1502 |
if self.non_audio_code_mask.device != scores.device or self.non_audio_code_mask.dtype != scores.dtype:
|
|
|
|
| 19 |
VALID_TIME_SIGNATURES,
|
| 20 |
)
|
| 21 |
|
| 22 |
+
# Maximum valid audio code value (codebook size = 64000)
|
| 23 |
+
MAX_AUDIO_CODE = 63999
|
| 24 |
+
|
| 25 |
|
| 26 |
# ==============================================================================
|
| 27 |
# FSM States for Constrained Decoding
|
|
|
|
| 517 |
"""
|
| 518 |
Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
|
| 519 |
These tokens should be blocked during caption generation.
|
| 520 |
+
Only tokens with code values in range [0, MAX_AUDIO_CODE] are included.
|
| 521 |
"""
|
| 522 |
import re
|
| 523 |
+
audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
|
| 524 |
+
invalid_tokens_count = 0
|
| 525 |
|
| 526 |
# Iterate through vocabulary to find audio code tokens
|
| 527 |
for token_id in range(self.vocab_size):
|
| 528 |
try:
|
| 529 |
token_text = self.tokenizer.decode([token_id])
|
| 530 |
+
match = audio_code_pattern.match(token_text)
|
| 531 |
+
if match:
|
| 532 |
+
# Extract code value from token text
|
| 533 |
+
code_value = int(match.group(1))
|
| 534 |
+
# Only add tokens with valid code values (0-63999)
|
| 535 |
+
if 0 <= code_value <= MAX_AUDIO_CODE:
|
| 536 |
+
self.audio_code_token_ids.add(token_id)
|
| 537 |
+
else:
|
| 538 |
+
invalid_tokens_count += 1
|
| 539 |
+
if self.debug:
|
| 540 |
+
logger.debug(f"Skipping audio code token {token_id} with invalid code value {code_value} (max: {MAX_AUDIO_CODE})")
|
| 541 |
except Exception:
|
| 542 |
continue
|
| 543 |
|
| 544 |
+
if invalid_tokens_count > 0:
|
| 545 |
+
logger.warning(f"Found {invalid_tokens_count} audio code tokens with values outside valid range [0, {MAX_AUDIO_CODE}]")
|
| 546 |
+
|
| 547 |
+
# Log warning if no valid tokens found (this would prevent code generation)
|
| 548 |
+
if len(self.audio_code_token_ids) == 0:
|
| 549 |
+
logger.warning(f"No valid audio code tokens found in vocabulary (range [0, {MAX_AUDIO_CODE}]). Code generation may fail.")
|
| 550 |
+
elif self.debug:
|
| 551 |
+
logger.debug(f"Found {len(self.audio_code_token_ids)} valid audio code tokens (range [0, {MAX_AUDIO_CODE}])")
|
| 552 |
+
|
| 553 |
+
def _extract_code_from_token(self, token_id: int) -> Optional[int]:
|
| 554 |
+
"""
|
| 555 |
+
Extract audio code value from a token ID.
|
| 556 |
+
|
| 557 |
+
Args:
|
| 558 |
+
token_id: Token ID to extract code value from
|
| 559 |
+
|
| 560 |
+
Returns:
|
| 561 |
+
Code value if token is a valid audio code token, None otherwise
|
| 562 |
+
"""
|
| 563 |
+
import re
|
| 564 |
+
audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
|
| 565 |
+
|
| 566 |
+
try:
|
| 567 |
+
token_text = self.tokenizer.decode([token_id])
|
| 568 |
+
match = audio_code_pattern.match(token_text)
|
| 569 |
+
if match:
|
| 570 |
+
return int(match.group(1))
|
| 571 |
+
except Exception:
|
| 572 |
+
pass
|
| 573 |
+
|
| 574 |
+
return None
|
| 575 |
|
| 576 |
def _build_audio_code_mask(self):
|
| 577 |
"""
|
|
|
|
| 1540 |
|
| 1541 |
if self.state == FSMState.CODES_GENERATION:
|
| 1542 |
# Block all non-audio-code tokens (only allow audio codes and EOS)
|
| 1543 |
+
# Note: audio_code_token_ids already contains only valid tokens (0-63999 range)
|
| 1544 |
+
# because _precompute_audio_code_tokens() filters out invalid tokens during initialization
|
| 1545 |
if self.non_audio_code_mask is not None:
|
| 1546 |
# Move mask to same device/dtype as scores if needed
|
| 1547 |
if self.non_audio_code_mask.device != scores.device or self.non_audio_code_mask.dtype != scores.dtype:
|
acestep/handler.py
CHANGED
|
@@ -774,11 +774,29 @@ class AceStepHandler:
|
|
| 774 |
return None
|
| 775 |
|
| 776 |
def _parse_audio_code_string(self, code_str: str) -> List[int]:
|
| 777 |
-
"""Extract integer audio codes from prompt tokens like <|audio_code_123|>.
|
|
|
|
|
|
|
| 778 |
if not code_str:
|
| 779 |
return []
|
| 780 |
try:
|
| 781 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 782 |
except Exception as e:
|
| 783 |
logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
|
| 784 |
return []
|
|
@@ -800,7 +818,8 @@ class AceStepHandler:
|
|
| 800 |
detokenizer = self.model.detokenizer
|
| 801 |
|
| 802 |
# Get codebook size for validation
|
| 803 |
-
|
|
|
|
| 804 |
if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
|
| 805 |
codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
|
| 806 |
|
|
|
|
| 774 |
return None
|
| 775 |
|
| 776 |
def _parse_audio_code_string(self, code_str: str) -> List[int]:
|
| 777 |
+
"""Extract integer audio codes from prompt tokens like <|audio_code_123|>.
|
| 778 |
+
Codes are clamped to valid range [0, 63999] (codebook size = 64000).
|
| 779 |
+
"""
|
| 780 |
if not code_str:
|
| 781 |
return []
|
| 782 |
try:
|
| 783 |
+
codes = [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
|
| 784 |
+
# Clamp codes to valid range [0, 63999]
|
| 785 |
+
MAX_AUDIO_CODE = 63999
|
| 786 |
+
clamped_codes = []
|
| 787 |
+
invalid_codes = []
|
| 788 |
+
for code in codes:
|
| 789 |
+
if code < 0 or code > MAX_AUDIO_CODE:
|
| 790 |
+
invalid_codes.append(code)
|
| 791 |
+
clamped_code = max(0, min(code, MAX_AUDIO_CODE))
|
| 792 |
+
clamped_codes.append(clamped_code)
|
| 793 |
+
else:
|
| 794 |
+
clamped_codes.append(code)
|
| 795 |
+
|
| 796 |
+
if invalid_codes:
|
| 797 |
+
logger.warning(f"[_parse_audio_code_string] Found {len(invalid_codes)} codes outside valid range [0, {MAX_AUDIO_CODE}]: {invalid_codes[:5]}... (clamped to valid range)")
|
| 798 |
+
|
| 799 |
+
return clamped_codes
|
| 800 |
except Exception as e:
|
| 801 |
logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
|
| 802 |
return []
|
|
|
|
| 818 |
detokenizer = self.model.detokenizer
|
| 819 |
|
| 820 |
# Get codebook size for validation
|
| 821 |
+
# Default to 64000 (codebook size = 64000, valid range = 0-63999)
|
| 822 |
+
codebook_size = getattr(quantizer, 'codebook_size', 64000)
|
| 823 |
if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
|
| 824 |
codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
|
| 825 |
|