Spaces:
Running
on
A100
Running
on
A100
reverse fix
Browse files- acestep/constrained_logits_processor.py +4 -41
- acestep/handler.py +8 -36
acestep/constrained_logits_processor.py
CHANGED
|
@@ -20,12 +20,6 @@ from acestep.constants import (
|
|
| 20 |
)
|
| 21 |
|
| 22 |
|
| 23 |
-
# ==============================================================================
|
| 24 |
-
# Constants
|
| 25 |
-
# ==============================================================================
|
| 26 |
-
# Maximum valid audio code value (codebook size = 64000, valid range: 0-63999)
|
| 27 |
-
MAX_AUDIO_CODE = 63999
|
| 28 |
-
|
| 29 |
# ==============================================================================
|
| 30 |
# FSM States for Constrained Decoding
|
| 31 |
# ==============================================================================
|
|
@@ -520,34 +514,21 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 520 |
"""
|
| 521 |
Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
|
| 522 |
These tokens should be blocked during caption generation.
|
| 523 |
-
Only tokens with code values in range [0, MAX_AUDIO_CODE] are included.
|
| 524 |
"""
|
| 525 |
import re
|
| 526 |
-
audio_code_pattern = re.compile(r'^<\|audio_code_
|
| 527 |
-
out_of_range_count = 0
|
| 528 |
|
| 529 |
# Iterate through vocabulary to find audio code tokens
|
| 530 |
for token_id in range(self.vocab_size):
|
| 531 |
try:
|
| 532 |
token_text = self.tokenizer.decode([token_id])
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
# Extract code value from token text
|
| 536 |
-
code_value = int(match.group(1))
|
| 537 |
-
# Only add tokens with valid code values (0-63999)
|
| 538 |
-
if 0 <= code_value <= MAX_AUDIO_CODE:
|
| 539 |
-
self.audio_code_token_ids.add(token_id)
|
| 540 |
-
else:
|
| 541 |
-
out_of_range_count += 1
|
| 542 |
-
if self.debug:
|
| 543 |
-
logger.debug(f"Skipping audio code token with out-of-range value: {token_text} (code={code_value})")
|
| 544 |
except Exception:
|
| 545 |
continue
|
| 546 |
|
| 547 |
if self.debug:
|
| 548 |
-
logger.debug(f"Found {len(self.audio_code_token_ids)}
|
| 549 |
-
if out_of_range_count > 0:
|
| 550 |
-
logger.warning(f"Skipped {out_of_range_count} audio code tokens with values outside valid range [0, {MAX_AUDIO_CODE}]")
|
| 551 |
|
| 552 |
def _build_audio_code_mask(self):
|
| 553 |
"""
|
|
@@ -1522,24 +1503,6 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1522 |
self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
|
| 1523 |
scores = scores + self.non_audio_code_mask
|
| 1524 |
|
| 1525 |
-
# Additional validation: block audio code tokens with out-of-range values
|
| 1526 |
-
# This prevents generation of codes > MAX_AUDIO_CODE even if they exist in vocabulary
|
| 1527 |
-
import re
|
| 1528 |
-
audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
|
| 1529 |
-
for token_id in self.audio_code_token_ids:
|
| 1530 |
-
try:
|
| 1531 |
-
token_text = self.tokenizer.decode([token_id])
|
| 1532 |
-
match = audio_code_pattern.match(token_text)
|
| 1533 |
-
if match:
|
| 1534 |
-
code_value = int(match.group(1))
|
| 1535 |
-
# Block tokens with code values outside valid range
|
| 1536 |
-
if code_value > MAX_AUDIO_CODE:
|
| 1537 |
-
scores[:, token_id] = float('-inf')
|
| 1538 |
-
if self.debug:
|
| 1539 |
-
logger.debug(f"Blocking out-of-range audio code token: {token_text} (code={code_value})")
|
| 1540 |
-
except Exception:
|
| 1541 |
-
continue
|
| 1542 |
-
|
| 1543 |
# Apply duration constraint in codes generation phase
|
| 1544 |
if self.target_codes is not None and self.eos_token_id is not None:
|
| 1545 |
if self.codes_count < self.target_codes:
|
|
|
|
| 20 |
)
|
| 21 |
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# ==============================================================================
|
| 24 |
# FSM States for Constrained Decoding
|
| 25 |
# ==============================================================================
|
|
|
|
| 514 |
"""
|
| 515 |
Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
|
| 516 |
These tokens should be blocked during caption generation.
|
|
|
|
| 517 |
"""
|
| 518 |
import re
|
| 519 |
+
audio_code_pattern = re.compile(r'^<\|audio_code_\d+\|>$')
|
|
|
|
| 520 |
|
| 521 |
# Iterate through vocabulary to find audio code tokens
|
| 522 |
for token_id in range(self.vocab_size):
|
| 523 |
try:
|
| 524 |
token_text = self.tokenizer.decode([token_id])
|
| 525 |
+
if audio_code_pattern.match(token_text):
|
| 526 |
+
self.audio_code_token_ids.add(token_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
except Exception:
|
| 528 |
continue
|
| 529 |
|
| 530 |
if self.debug:
|
| 531 |
+
logger.debug(f"Found {len(self.audio_code_token_ids)} audio code tokens")
|
|
|
|
|
|
|
| 532 |
|
| 533 |
def _build_audio_code_mask(self):
|
| 534 |
"""
|
|
|
|
| 1503 |
self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
|
| 1504 |
scores = scores + self.non_audio_code_mask
|
| 1505 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1506 |
# Apply duration constraint in codes generation phase
|
| 1507 |
if self.target_codes is not None and self.eos_token_id is not None:
|
| 1508 |
if self.codes_count < self.target_codes:
|
acestep/handler.py
CHANGED
|
@@ -774,32 +774,11 @@ class AceStepHandler:
|
|
| 774 |
return None
|
| 775 |
|
| 776 |
def _parse_audio_code_string(self, code_str: str) -> List[int]:
|
| 777 |
-
"""
|
| 778 |
-
Extract integer audio codes from prompt tokens like <|audio_code_123|>.
|
| 779 |
-
Clamps code values to valid range [0, 63999].
|
| 780 |
-
"""
|
| 781 |
if not code_str:
|
| 782 |
return []
|
| 783 |
try:
|
| 784 |
-
|
| 785 |
-
codes = [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
|
| 786 |
-
# Clamp codes to valid range [0, 63999]
|
| 787 |
-
clamped_codes = []
|
| 788 |
-
clamped_count = 0
|
| 789 |
-
for code in codes:
|
| 790 |
-
if code < 0:
|
| 791 |
-
clamped_codes.append(0)
|
| 792 |
-
clamped_count += 1
|
| 793 |
-
elif code > MAX_AUDIO_CODE:
|
| 794 |
-
clamped_codes.append(MAX_AUDIO_CODE)
|
| 795 |
-
clamped_count += 1
|
| 796 |
-
else:
|
| 797 |
-
clamped_codes.append(code)
|
| 798 |
-
|
| 799 |
-
if clamped_count > 0:
|
| 800 |
-
logger.warning(f"[_parse_audio_code_string] Clamped {clamped_count} audio code values to valid range [0, {MAX_AUDIO_CODE}]")
|
| 801 |
-
|
| 802 |
-
return clamped_codes
|
| 803 |
except Exception as e:
|
| 804 |
logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
|
| 805 |
return []
|
|
@@ -821,23 +800,16 @@ class AceStepHandler:
|
|
| 821 |
detokenizer = self.model.detokenizer
|
| 822 |
|
| 823 |
# Get codebook size for validation
|
| 824 |
-
|
| 825 |
-
MAX_AUDIO_CODE = 63999
|
| 826 |
-
codebook_size = getattr(quantizer, 'codebook_size', 64000)
|
| 827 |
if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
|
| 828 |
codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
|
| 829 |
|
| 830 |
-
#
|
| 831 |
-
|
| 832 |
-
effective_codebook_size = 64000
|
| 833 |
-
effective_max_code = MAX_AUDIO_CODE
|
| 834 |
-
|
| 835 |
-
# Validate code IDs are within valid range [0, 63999]
|
| 836 |
-
invalid_codes = [c for c in code_ids if c < 0 or c > effective_max_code]
|
| 837 |
if invalid_codes:
|
| 838 |
-
logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {
|
| 839 |
-
# Clamp invalid codes to valid range
|
| 840 |
-
code_ids = [max(0, min(c,
|
| 841 |
|
| 842 |
num_quantizers = getattr(quantizer, "num_quantizers", 1)
|
| 843 |
# Create indices tensor: [T_5Hz]
|
|
|
|
| 774 |
return None
|
| 775 |
|
| 776 |
def _parse_audio_code_string(self, code_str: str) -> List[int]:
|
| 777 |
+
"""Extract integer audio codes from prompt tokens like <|audio_code_123|>."""
|
|
|
|
|
|
|
|
|
|
| 778 |
if not code_str:
|
| 779 |
return []
|
| 780 |
try:
|
| 781 |
+
return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 782 |
except Exception as e:
|
| 783 |
logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
|
| 784 |
return []
|
|
|
|
| 800 |
detokenizer = self.model.detokenizer
|
| 801 |
|
| 802 |
# Get codebook size for validation
|
| 803 |
+
codebook_size = getattr(quantizer, 'codebook_size', 65536)
|
|
|
|
|
|
|
| 804 |
if hasattr(quantizer, 'quantizers') and len(quantizer.quantizers) > 0:
|
| 805 |
codebook_size = getattr(quantizer.quantizers[0], 'codebook_size', codebook_size)
|
| 806 |
|
| 807 |
+
# Validate code IDs are within valid range
|
| 808 |
+
invalid_codes = [c for c in code_ids if c < 0 or c >= codebook_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 809 |
if invalid_codes:
|
| 810 |
+
logger.warning(f"[_decode_audio_codes_to_latents] Found {len(invalid_codes)} invalid codes out of range [0, {codebook_size}): {invalid_codes[:5]}...")
|
| 811 |
+
# Clamp invalid codes to valid range
|
| 812 |
+
code_ids = [max(0, min(c, codebook_size - 1)) for c in code_ids]
|
| 813 |
|
| 814 |
num_quantizers = getattr(quantizer, "num_quantizers", 1)
|
| 815 |
# Create indices tensor: [T_5Hz]
|