diff --git "a/code/acestep/constrained_logits_processor.py" "b/code/acestep/constrained_logits_processor.py"
new file mode 100644--- /dev/null
+++ "b/code/acestep/constrained_logits_processor.py"
@@ -0,0 +1,2230 @@
+
+from enum import Enum, auto
+from typing import Optional, Dict, Any, Tuple, List, Callable, Set
+from loguru import logger
+from transformers import AutoTokenizer
+from transformers.generation.logits_process import LogitsProcessor
+import os
+import torch
+from acestep.constants import (
+ VALID_LANGUAGES,
+ KEYSCALE_NOTES,
+ KEYSCALE_ACCIDENTALS,
+ KEYSCALE_MODES,
+ VALID_KEYSCALES,
+ BPM_MIN,
+ BPM_MAX,
+ DURATION_MIN,
+ DURATION_MAX,
+ VALID_TIME_SIGNATURES,
+)
+
+
+# ==============================================================================
+# FSM States for Constrained Decoding
+# ==============================================================================
+class FSMState(Enum):
+ """Finite State Machine states for metadata generation"""
+ THINK_TAG = auto() # Generating ""
+ NEWLINE_AFTER_THINK = auto() # Generating "\n" after
+ BPM_NAME = auto() # Generating "bpm: "
+ BPM_VALUE = auto() # Generating numeric value 30-300
+ NEWLINE_AFTER_BPM = auto() # Generating "\n" after bpm value
+ CAPTION_NAME = auto() # Generating "caption: "
+ CAPTION_VALUE = auto() # Generating caption text (no code blocks/newlines)
+ DURATION_NAME = auto() # Generating "duration: "
+ DURATION_VALUE = auto() # Generating numeric value 10-600
+ NEWLINE_AFTER_DURATION = auto()
+ GENRES_NAME = auto() # Generating "genres: "
+ GENRES_VALUE = auto() # Generating any non-empty string
+ NEWLINE_AFTER_GENRES = auto()
+ KEYSCALE_NAME = auto() # Generating "keyscale: "
+ KEYSCALE_VALUE = auto() # Generating keyscale pattern
+ NEWLINE_AFTER_KEYSCALE = auto()
+ LANGUAGE_NAME = auto() # Generating "language: "
+ LANGUAGE_VALUE = auto() # Generating language code (en, zh, ja, etc.)
+ TIMESIG_NAME = auto() # Generating "timesignature: "
+ TIMESIG_VALUE = auto() # Generating 2, 3, 4, or 6
+ NEWLINE_AFTER_TIMESIG = auto()
+ THINK_END_TAG = auto() # Generating ""
+ CODES_GENERATION = auto() # Generating audio codes (no constraints)
+ COMPLETED = auto() # Generation completed
+
+
+class MetadataConstrainedLogitsProcessor(LogitsProcessor):
+ """
+ FSM-driven LogitsProcessor that constrains generation to produce valid metadata.
+
+ This processor enforces the following format:
+
+ bpm: [30-300]
+ caption: [text without code blocks, ends with period + newline]
+ duration: [10-600]
+ keyscale: [A-G][#/♭]? [major/minor]
+ language: [en/zh/ja/ko/es/fr/de/uk/ru/...]
+ timesignature: [2/3/4/6]
+
+
+ It uses token masking (setting invalid token logits to -inf) to enforce constraints.
+ For numeric fields, it uses early-blocking to prevent out-of-range values.
+ For field transitions (e.g., end of numeric value), it compares P(newline) vs P(digit).
+ For caption field, it blocks code blocks and newlines, and only transitions when
+ the previous token was a period and newline has the highest probability.
+ """
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ enabled: bool = True,
+ debug: bool = False,
+ genres_vocab_path: Optional[str] = None,
+ skip_genres: bool = True,
+ ):
+ """
+ Initialize the constrained logits processor.
+
+ This processor should be initialized once when loading the LLM and reused
+ for all generations.
+ Args:
+ tokenizer: The tokenizer to use for encoding/decoding
+ enabled: Whether to enable constrained decoding
+ debug: Whether to print debug information
+ """
+ self.tokenizer = tokenizer
+ self.enabled = enabled
+ self.debug = debug
+ self.skip_genres = skip_genres
+ self.skip_caption = False # Set to True to skip caption field generation
+ self.skip_language = False # Set to True to skip language field generation
+ self.caption: Optional[str] = None # Set via update_caption() before each generation
+
+ # User-provided metadata fields (optional)
+ # If provided, these fields will be used directly instead of generating
+ # Format: {"bpm": "120", "caption": "...", "duration": "234", "keyscale": "G major", "language": "en", "timesignature": "4"}
+ self.user_provided_metadata: Dict[str, Optional[str]] = {
+ "bpm": None,
+ "caption": None,
+ "duration": None,
+ "keyscale": None,
+ "language": None,
+ "timesignature": None,
+ "genres": None,
+ }
+
+ # Temperature settings for different generation phases (set per-generation)
+ # If set, the processor will apply temperature scaling (divide logits by temperature)
+ # Note: Set base sampler temperature to 1.0 when using processor-based temperature
+ self.metadata_temperature: Optional[float] = None
+ self.codes_temperature: Optional[float] = None
+
+ # Duration constraint for codes generation
+ # 5 codes = 1 second, so target_codes = target_duration * 5
+ self.target_duration: Optional[float] = None # User-specified duration in seconds
+ self.target_codes: Optional[int] = None # Computed target codes count
+ self.codes_count: int = 0 # Counter for generated codes
+
+ # Stop at reasoning flag - if True, stop generation after tag
+ self.stop_at_reasoning: bool = False
+
+ # Generation phase - "cot" or "codes"
+ # Used to determine FSM behavior when prompt already contains CoT
+ self.generation_phase: str = "cot"
+
+ # Current state
+ self.state = FSMState.THINK_TAG
+ self.position_in_state = 0 # Position within current state's fixed string
+ self.accumulated_value = "" # For numeric/text value accumulation (legacy, for compatibility)
+ self.accumulated_token_ids: List[int] = [] # Token ID sequence for keyscale (and other fields)
+
+ # Caption generation state tracking
+ self.caption_after_newline = False # Track if we're right after a newline in caption
+ self.caption_token_count = 0 # Track token count for caption (max 512)
+ self.caption_ending = False # Track if caption is ending (after detecting non-indented line)
+ self.pending_field_name = "" # Accumulate field name tokens when caption is ending
+
+ # Token queue for user-provided fields (injected directly without generation)
+ self.user_field_token_queue: List[int] = []
+ self.current_user_field: Optional[str] = None # Current field being injected
+
+ # Pre-compute token IDs for efficiency
+ self._precompute_tokens()
+
+ # Genres vocabulary for constrained decoding
+ self.genres_vocab_path = genres_vocab_path or os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "genres_vocab.txt"
+ )
+ self.genres_vocab: List[str] = [] # Full vocab
+ self.genres_vocab_mtime: float = 0.0
+ self.genres_trie: Dict = {} # Trie for full vocab (fallback)
+ self.caption_genres_trie: Dict = {} # Trie for caption-matched genres (priority)
+ self.caption_matched_genres: List[str] = [] # Genres matched from caption
+
+ self._char_to_tokens: Dict[str, set] = {} # Precomputed char -> token IDs mapping
+
+ # Precompute token mappings once (O(vocab_size), runs once at init)
+ self._precompute_char_token_mapping()
+
+ # Field definitions (needed before building prefix trees)
+ self.field_specs = {
+ "bpm": {"min": BPM_MIN, "max": BPM_MAX},
+ "duration": {"min": DURATION_MIN, "max": DURATION_MAX},
+ "timesignature": {"valid_values": VALID_TIME_SIGNATURES},
+ }
+
+ # Build valid numeric values for BPM, Duration, Timesignature
+ # These will be used to build prefix trees based on actual tokenization
+ self.valid_bpm_values = [str(v) for v in range(self.field_specs["bpm"]["min"], self.field_specs["bpm"]["max"] + 1)]
+ self.valid_duration_values = [str(v) for v in range(self.field_specs["duration"]["min"], self.field_specs["duration"]["max"] + 1)]
+ self.valid_timesig_values = [str(v) for v in self.field_specs["timesignature"]["valid_values"]]
+
+ # Build keyscale prefix tree (requires _char_to_tokens to be initialized)
+ self.keyscale_prefix_tree = self._build_keyscale_prefix_tree()
+
+ # Build numeric prefix trees (BPM, Duration, Timesignature) with context
+ # IMPORTANT: State machine generates "bpm:" (no space), but tokenizer sees "bpm: " (with space)
+ # Use same logic as keyscale: context_prefix_for_matching (no space) and context_prefix_for_tokenization (with space)
+ self.bpm_prefix_tree = self._build_numeric_prefix_tree(
+ self.valid_bpm_values,
+ context_prefix_for_matching="bpm:",
+ context_prefix_for_tokenization="bpm: "
+ )
+ self.duration_prefix_tree = self._build_numeric_prefix_tree(
+ self.valid_duration_values,
+ context_prefix_for_matching="duration:",
+ context_prefix_for_tokenization="duration: "
+ )
+ self.timesig_prefix_tree = self._build_numeric_prefix_tree(
+ self.valid_timesig_values,
+ context_prefix_for_matching="timesignature:",
+ context_prefix_for_tokenization="timesignature: "
+ )
+
+ # Build language prefix tree (similar to keyscale but for language codes)
+ self.language_prefix_tree = self._build_language_prefix_tree()
+
+ self._load_genres_vocab()
+
+ # Fixed strings for each state
+ # IMPORTANT: Do NOT include trailing space after colon - tokenizer will handle spacing
+ # All matching should be done at token level, not string level
+ # NOTE: NEWLINE_AFTER_* states are removed - field values generate newline directly and transition to next field
+ self.fixed_strings = {
+ FSMState.THINK_TAG: "",
+ FSMState.NEWLINE_AFTER_THINK: "\n",
+ FSMState.BPM_NAME: "bpm:",
+ FSMState.CAPTION_NAME: "caption:",
+ FSMState.DURATION_NAME: "duration:",
+ FSMState.GENRES_NAME: "genres:",
+ FSMState.KEYSCALE_NAME: "keyscale:",
+ FSMState.LANGUAGE_NAME: "language:",
+ FSMState.TIMESIG_NAME: "timesignature:",
+ FSMState.THINK_END_TAG: "",
+ }
+
+ # State transitions
+ self._build_state_transitions()
+
+ def _get_next_field_state(self, current_field: str) -> Optional[FSMState]:
+ """
+ Get the next field state. Always returns the next field's NAME state,
+ even if the field is user-provided (we still need to generate the field name).
+
+ Args:
+ current_field: Current field name ("bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature")
+
+ Returns:
+ Next FSMState (NAME state of next field), or THINK_END_TAG if no more fields
+ """
+ # New field order: bpm -> caption -> duration -> keyscale -> language -> timesignature
+ # genres is optional and can be skipped
+ field_order = ["bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature"]
+ field_to_state = {
+ "bpm": FSMState.BPM_NAME,
+ "caption": FSMState.CAPTION_NAME,
+ "duration": FSMState.DURATION_NAME,
+ "genres": FSMState.GENRES_NAME,
+ "keyscale": FSMState.KEYSCALE_NAME,
+ "language": FSMState.LANGUAGE_NAME,
+ "timesignature": FSMState.TIMESIG_NAME,
+ }
+
+ try:
+ current_idx = field_order.index(current_field)
+ except ValueError:
+ return FSMState.THINK_END_TAG
+
+ # Find next field in order
+ for i in range(current_idx + 1, len(field_order)):
+ field = field_order[i]
+
+ # Skip fields based on flags
+ if field == "genres" and self.skip_genres:
+ continue
+ if field == "caption" and self.skip_caption:
+ continue
+ if field == "language" and self.skip_language:
+ continue
+
+ # Return the next field's NAME state (even if user-provided, we still generate field name)
+ return field_to_state[field]
+
+ # No more fields, go to THINK_END_TAG
+ return FSMState.THINK_END_TAG
+
+ def _build_state_transitions(self):
+ """Build state transition map based on user-provided metadata."""
+ self.next_state = {
+ FSMState.THINK_TAG: FSMState.NEWLINE_AFTER_THINK,
+ FSMState.NEWLINE_AFTER_THINK: FSMState.BPM_NAME, # Always start with BPM
+ FSMState.THINK_END_TAG: FSMState.CODES_GENERATION,
+ FSMState.CODES_GENERATION: FSMState.COMPLETED,
+ }
+
+ # Build transitions for all fields (even if user-provided, we still need to generate field name)
+ # Field order: bpm -> caption -> duration -> genres -> keyscale -> language -> timesignature
+
+ # BPM field: NAME -> VALUE -> next field (caption or duration)
+ self.next_state[FSMState.BPM_NAME] = FSMState.BPM_VALUE
+ self.next_state[FSMState.BPM_VALUE] = self._get_next_field_state("bpm")
+
+ # Caption field (only if not skipped): NAME -> VALUE -> next field (duration)
+ if not self.skip_caption:
+ self.next_state[FSMState.CAPTION_NAME] = FSMState.CAPTION_VALUE
+ self.next_state[FSMState.CAPTION_VALUE] = self._get_next_field_state("caption")
+
+ # Duration field: NAME -> VALUE -> next field
+ self.next_state[FSMState.DURATION_NAME] = FSMState.DURATION_VALUE
+ self.next_state[FSMState.DURATION_VALUE] = self._get_next_field_state("duration")
+
+ # Genres field (only if not skipped): NAME -> VALUE -> next field
+ if not self.skip_genres:
+ self.next_state[FSMState.GENRES_NAME] = FSMState.GENRES_VALUE
+ self.next_state[FSMState.GENRES_VALUE] = self._get_next_field_state("genres")
+
+ # Keyscale field: NAME -> VALUE -> next field (language or timesignature)
+ self.next_state[FSMState.KEYSCALE_NAME] = FSMState.KEYSCALE_VALUE
+ self.next_state[FSMState.KEYSCALE_VALUE] = self._get_next_field_state("keyscale")
+
+ # Language field (only if not skipped): NAME -> VALUE -> next field (timesignature)
+ if not self.skip_language:
+ self.next_state[FSMState.LANGUAGE_NAME] = FSMState.LANGUAGE_VALUE
+ self.next_state[FSMState.LANGUAGE_VALUE] = self._get_next_field_state("language")
+
+ # Timesignature field: NAME -> VALUE -> THINK_END_TAG
+ self.next_state[FSMState.TIMESIG_NAME] = FSMState.TIMESIG_VALUE
+ self.next_state[FSMState.TIMESIG_VALUE] = FSMState.THINK_END_TAG
+
+ def set_skip_genres(self, skip: bool):
+ """Set whether to skip genres generation and rebuild state transitions."""
+ self.skip_genres = skip
+ self._build_state_transitions()
+
+ def set_skip_caption(self, skip: bool):
+ """Set whether to skip caption generation and rebuild state transitions."""
+ self.skip_caption = skip
+ self._build_state_transitions()
+
+ def set_skip_language(self, skip: bool):
+ """Set whether to skip language generation and rebuild state transitions."""
+ self.skip_language = skip
+ self._build_state_transitions()
+
+ @staticmethod
+ def postprocess_caption(caption: str) -> str:
+ """
+ Post-process caption to remove YAML multi-line formatting.
+ Converts YAML-style multi-line text (with newlines and leading spaces)
+ to a single-line string.
+
+ Example:
+ Input: "An emotional ballad.\\n The track opens with piano.\\n More text."
+ Output: "An emotional ballad. The track opens with piano. More text."
+
+ Args:
+ caption: Raw caption text with possible YAML formatting
+
+ Returns:
+ Clean single-line caption
+ """
+ if not caption:
+ return caption
+
+ # Split by newlines
+ lines = caption.split('\n')
+
+ # Process each line: strip leading/trailing whitespace
+ cleaned_lines = []
+ for line in lines:
+ stripped = line.strip()
+ if stripped:
+ cleaned_lines.append(stripped)
+
+ # Join with single space
+ return ' '.join(cleaned_lines)
+
+ def set_stop_at_reasoning(self, stop: bool):
+ """
+ Set whether to stop generation after tag.
+
+ Args:
+ stop: If True, generation will stop immediately after tag is generated.
+ If False, generation continues to codes generation phase.
+ """
+ self.stop_at_reasoning = stop
+
+ def set_generation_phase(self, phase: str):
+ """
+ Set the generation phase.
+
+ Args:
+ phase: "cot" for CoT metadata generation, "codes" for audio codes generation,
+ or "understand" for audio understanding (codes → metadata + lyrics).
+ When phase is "codes" and the input prompt already contains ,
+ the FSM will skip metadata generation and go directly to codes generation.
+ When phase is "understand", generate CoT metadata then free-form lyrics.
+ """
+ if phase not in ("cot", "codes", "understand"):
+ raise ValueError(f"Invalid generation phase: {phase!r}. Must be 'cot', 'codes', or 'understand'")
+ self.generation_phase = phase
+
+ def set_user_metadata(self, metadata: Optional[Dict[str, Optional[str]]] = None):
+ """
+ Set user-provided metadata fields. Fields that are provided will be used directly
+ instead of generating. Fields that are None will be generated.
+
+ Args:
+ metadata: Dictionary with optional fields:
+ - "bpm": Optional[str] - e.g., "120"
+ - "caption": Optional[str] - e.g., "A melodic piano piece..."
+ - "duration": Optional[str] - e.g., "234"
+ - "keyscale": Optional[str] - e.g., "G major"
+ - "language": Optional[str] - e.g., "en"
+ - "timesignature": Optional[str] - e.g., "4"
+ - "genres": Optional[str] - e.g., "Pop Rock"
+ If None, clears all user-provided metadata.
+ """
+ if metadata is None:
+ metadata = {}
+
+ # Update user-provided metadata
+ for field in ["bpm", "caption", "duration", "keyscale", "language", "timesignature", "genres"]:
+ if field in metadata:
+ self.user_provided_metadata[field] = metadata[field]
+ else:
+ self.user_provided_metadata[field] = None
+
+ # Rebuild state transitions to skip provided fields
+ self._build_state_transitions()
+
+ if self.debug:
+ provided_fields = [k for k, v in self.user_provided_metadata.items() if v is not None]
+ if provided_fields:
+ logger.debug(f"User provided metadata fields: {provided_fields}")
+ else:
+ logger.debug("No user-provided metadata, all fields will be generated")
+
+ def _precompute_tokens(self):
+ """Pre-compute commonly used token IDs for efficiency."""
+ # Digit tokens (0-9)
+ self.digit_tokens = {}
+ for d in range(10):
+ tokens = self.tokenizer.encode(str(d), add_special_tokens=False)
+ if tokens:
+ self.digit_tokens[d] = tokens[-1] # Take last token (in case of prefix)
+
+ # Newline token
+ newline_tokens = self.tokenizer.encode("\n", add_special_tokens=False)
+ self.newline_token = newline_tokens[-1] if newline_tokens else None
+
+ # Note tokens for keyscale (A-G)
+ self.note_tokens = {}
+ for note in KEYSCALE_NOTES:
+ tokens = self.tokenizer.encode(note, add_special_tokens=False)
+ if tokens:
+ self.note_tokens[note] = tokens[-1]
+
+ # Sharp/flat tokens
+ self.sharp_tokens = []
+ for s in ["#", "♯"]:
+ tokens = self.tokenizer.encode(s, add_special_tokens=False)
+ if tokens:
+ self.sharp_tokens.append(tokens[-1])
+
+ self.flat_tokens = []
+ for f in ["b", "♭"]:
+ tokens = self.tokenizer.encode(f, add_special_tokens=False)
+ if tokens:
+ self.flat_tokens.append(tokens[-1])
+
+ # Space token
+ space_tokens = self.tokenizer.encode(" ", add_special_tokens=False)
+ self.space_token = space_tokens[-1] if space_tokens else None
+
+ # Major/minor tokens (we'll encode the full words)
+ self.major_start_tokens = []
+ self.minor_start_tokens = []
+ for prefix in ["m", "M"]:
+ tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
+ if tokens:
+ if prefix.lower() == "m":
+ self.minor_start_tokens.append(tokens[-1])
+ self.major_start_tokens.append(tokens[-1]) # "major" also starts with m
+
+ # Vocab size
+ self.vocab_size = len(self.tokenizer)
+
+ # Comma token for multi-genre support
+ comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
+ self.comma_token = comma_tokens[-1] if comma_tokens else None
+
+ # EOS token for duration-constrained codes generation
+ self.eos_token_id = self.tokenizer.eos_token_id
+
+ # Period token for caption field transition logic
+ period_tokens = self.tokenizer.encode(".", add_special_tokens=False)
+ self.period_token = period_tokens[-1] if period_tokens else None
+
+ # Backtick tokens for blocking code blocks in caption
+ backtick_tokens = self.tokenizer.encode("`", add_special_tokens=False)
+ self.backtick_token = backtick_tokens[-1] if backtick_tokens else None
+
+ # Valid language codes (ISO 639-1 and common variants)
+ self.valid_languages = VALID_LANGUAGES
+
+ # Precompute audio code token IDs (tokens matching <|audio_code_\d+|>)
+ # These should be blocked during caption generation
+ self.audio_code_token_ids: Set[int] = set()
+ self._precompute_audio_code_tokens()
+
+ # Precompute audio code mask for efficient blocking (O(1) instead of O(n))
+ # This mask will be added to scores during caption generation
+ self.audio_code_mask: Optional[torch.Tensor] = None
+ # Inverse mask: block all non-audio-code tokens (for CODES_GENERATION state)
+ self.non_audio_code_mask: Optional[torch.Tensor] = None
+ self._build_audio_code_mask()
+
+ # Build valid keyscales set (prefix tree will be built after _char_to_tokens is initialized)
+ # 7 notes × 5 accidentals (none, #, b, ♯, ♭) × 2 modes = 70 valid combinations
+ self.valid_keyscales = VALID_KEYSCALES.copy()
+
+ # keyscale_prefix_tree will be built in _precompute_char_token_mapping() after _char_to_tokens is ready
+ # Numeric prefix trees will be built after field_specs is defined
+
+ def _precompute_audio_code_tokens(self):
+ """
+ Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
+ These tokens should be blocked during caption generation.
+ """
+ import re
+ audio_code_pattern = re.compile(r'^<\|audio_code_\d+\|>$')
+
+ # Iterate through vocabulary to find audio code tokens
+ for token_id in range(self.vocab_size):
+ try:
+ token_text = self.tokenizer.decode([token_id])
+ if audio_code_pattern.match(token_text):
+ self.audio_code_token_ids.add(token_id)
+ except Exception:
+ continue
+
+ if self.debug:
+ logger.debug(f"Found {len(self.audio_code_token_ids)} audio code tokens")
+
+ def _build_audio_code_mask(self):
+ """
+ Build a precomputed mask tensor for blocking audio code tokens.
+ This mask can be added to scores in O(1) time instead of O(n) loop.
+
+ The mask is [1, vocab_size] tensor with -inf at audio code token positions.
+
+ Also builds the inverse mask (non_audio_code_mask) for CODES_GENERATION state,
+ which blocks all non-audio-code tokens.
+ """
+ if not self.audio_code_token_ids:
+ self.audio_code_mask = None
+ self.non_audio_code_mask = None
+ return
+
+ # Create mask tensor: 0 everywhere, -inf at audio code positions
+ # Use float32 for compatibility with most model dtypes
+ mask = torch.zeros(1, self.vocab_size, dtype=torch.float32)
+
+ # Convert set to list for indexing
+ audio_code_indices = list(self.audio_code_token_ids)
+
+ # Set -inf at audio code token positions
+ mask[0, audio_code_indices] = float('-inf')
+
+ self.audio_code_mask = mask
+
+ # Build inverse mask: -inf everywhere EXCEPT at audio code positions
+ # This is used in CODES_GENERATION state to only allow audio codes
+ inverse_mask = torch.full((1, self.vocab_size), float('-inf'), dtype=torch.float32)
+ inverse_mask[0, audio_code_indices] = 0
+
+ # Also allow EOS token in codes generation (will be controlled by duration constraint)
+ if self.eos_token_id is not None:
+ inverse_mask[0, self.eos_token_id] = 0
+
+ self.non_audio_code_mask = inverse_mask
+
+ if self.debug:
+ logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens")
+
+ def _apply_whitelist_inplace(self, scores: torch.Tensor, allowed_tokens: List[int]) -> None:
+ """
+ Apply whitelist constraint inplace: only allow specified tokens, block all others.
+
+ This is more efficient than creating a mask tensor because:
+ 1. No memory allocation for mask
+ 2. No tensor addition operation
+
+ Args:
+ scores: [1, vocab_size] scores tensor to modify inplace
+ allowed_tokens: List of token IDs to allow (all others will be set to -inf)
+ """
+ if not allowed_tokens:
+ # No tokens allowed, set all to -inf
+ scores.fill_(float('-inf'))
+ return
+
+ # Save the original values of allowed tokens
+ allowed_indices = torch.tensor(allowed_tokens, device=scores.device, dtype=torch.long)
+ saved_values = scores[0, allowed_indices].clone()
+
+ # Set all scores to -inf
+ scores.fill_(float('-inf'))
+
+ # Restore allowed token values
+ scores[0, allowed_indices] = saved_values
+
+ def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
+ """
+ Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.
+
+ IMPORTANT: Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches.
+
+ CRITICAL FIX: The tokenizer may merge the context's trailing space into the next token.
+ For example:
+ - "keyscale: " tokenizes to [10563, 2246, 25, 220] -> ['keys', 'cale', ':', ' ']
+ - "keyscale: G major" tokenizes to [10563, 2246, 25, 479, 3598] -> ['keys', 'cale', ':', ' G', ' major']
+ The space ' ' (220) is merged into ' G' (479), so we can't use simple slicing.
+
+ Strategy:
+ 1. For each keyscale (e.g., "G major"), encode the FULL string "keyscale: G major"
+ 2. Tokenize to get: [10563, 2246, 25, 479, 3598] -> ['keys', 'cale', ':', ' G', ' major']
+ 3. Find where context prefix ends by matching token sequences (handling space merging)
+ 4. Extract keyscale value tokens: [479, 3598] (for "G major")
+ 5. Build prefix tree using token ID sequences as keys
+
+ This ensures we get the exact tokenization that occurs during generation.
+ """
+ prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}
+
+ # Context prefix that appears before keyscale value
+ # IMPORTANT: The state machine generates "keyscale:" (no space), but when tokenizing
+ # the full string "keyscale: G major", the tokenizer includes space, so we need to
+ # match the actual tokenization behavior.
+ #
+ # Strategy:
+ # 1. Use "keyscale:" (no space) to match the state machine's output
+ # 2. But when building prefix tree, use "keyscale: " (with space) + keyscale to match actual tokenization
+ context_prefix_for_matching = "keyscale:" # What state machine generates
+ context_prefix_for_tokenization = "keyscale: " # What tokenizer sees in full string
+
+ # First, tokenize the context (without space) to know its token sequence for matching
+ context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False)
+
+ if self.debug:
+ context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids]
+ logger.debug(f"Context for matching 'keyscale:' tokenizes to {context_token_ids} -> {context_tokens_str}")
+
+ # For each valid keyscale, encode full string and extract value tokens
+ for keyscale in self.valid_keyscales:
+ # Step 1: Encode full string "keyscale: {keyscale}" (with space, as tokenizer sees it)
+ full_text = context_prefix_for_tokenization + keyscale
+ full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)
+
+ # Step 2: Find where context ends in full_token_ids
+ # We match using context_prefix_for_matching ("keyscale:") token sequence
+ # because that's what the state machine actually generates
+ context_end_idx = None
+
+ # Try exact prefix match using context_prefix_for_matching token sequence
+ if len(full_token_ids) >= len(context_token_ids):
+ if full_token_ids[:len(context_token_ids)] == context_token_ids:
+ context_end_idx = len(context_token_ids)
+
+ if context_end_idx is None:
+ if self.debug:
+ logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
+ continue
+
+ # Step 3: Extract keyscale value tokens (everything after context)
+ keyscale_token_ids = full_token_ids[context_end_idx:]
+
+ # Step 4: Verify we extracted some tokens (sanity check)
+ if not keyscale_token_ids:
+ if self.debug:
+ logger.warning(f"No tokens extracted for keyscale '{keyscale}', skipping")
+ continue
+
+ # Step 5: Verify first token is a note (A-G)
+ # This is critical: the first token of keyscale value must be a note
+ first_token_id = keyscale_token_ids[0]
+ first_token_str = self.tokenizer.decode([first_token_id])
+ # Check if first token starts with a note (A-G, case insensitive, with optional leading space)
+ first_char = first_token_str.lstrip()[0].upper() if first_token_str.lstrip() else ""
+ if first_char not in "ABCDEFG":
+ # This keyscale's first token is not a note - skip it
+ if self.debug:
+ logger.debug(f"Skipping keyscale '{keyscale}': first token is '{first_token_str}' (id={first_token_id}), not a note")
+ continue
+
+ # Step 6: Build prefix mappings from keyscale value tokens
+ # Use token ID sequences as keys (not strings) to avoid tokenization mismatches
+ for i in range(len(keyscale_token_ids) + 1):
+ # Current token sequence prefix (empty tuple for start)
+ token_prefix = tuple(keyscale_token_ids[:i])
+
+ if token_prefix not in prefix_to_tokens:
+ prefix_to_tokens[token_prefix] = set()
+
+ if i < len(keyscale_token_ids):
+ # Add next token as allowed for current prefix
+ next_token_id = keyscale_token_ids[i]
+ prefix_to_tokens[token_prefix].add(next_token_id)
+ else:
+ # Complete keyscale should allow newline
+ if self.newline_token:
+ prefix_to_tokens[token_prefix].add(self.newline_token)
+
+ if self.debug:
+ logger.debug(f"Built keyscale prefix tree with {len(prefix_to_tokens)} token sequence prefixes")
+ # Check empty prefix (start of keyscale value)
+ empty_prefix = tuple()
+ if empty_prefix in prefix_to_tokens:
+ first_tokens = prefix_to_tokens[empty_prefix]
+ decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)]
+ logger.debug(f"First tokens allowed (empty prefix): {decoded_first}")
+
+ return prefix_to_tokens
+
+ def _build_numeric_prefix_tree(
+ self,
+ valid_values: List[str],
+ context_prefix_for_matching: str = "",
+ context_prefix_for_tokenization: str = ""
+ ) -> Dict[Tuple[int, ...], Set[int]]:
+ """
+ Build prefix tree for numeric field based on actual tokenization with context.
+
+ IMPORTANT: Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches.
+
+ Args:
+ valid_values: List of valid numeric strings (e.g., ["30", "31", ..., "300"])
+ context_prefix_for_matching: Context string that state machine generates (e.g., "bpm:") - no space
+ context_prefix_for_tokenization: Context string for tokenization (e.g., "bpm: ") - with space
+
+ Returns:
+ Dict mapping token ID sequence prefix -> set of allowed token IDs
+ """
+ prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}
+
+ # Encode context for matching (what state machine generates, no space)
+ context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False) if context_prefix_for_matching else []
+
+ # For each valid value, encode it with context and build prefix mappings
+ for value_str in valid_values:
+ # Encode value WITH context (with space) to match actual tokenization
+ full_text = context_prefix_for_tokenization + value_str
+ token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)
+
+ # Find where context ends in full_token_ids using context_prefix_for_matching token sequence
+ context_end_idx = None
+ if len(token_ids) >= len(context_token_ids):
+ if token_ids[:len(context_token_ids)] == context_token_ids:
+ context_end_idx = len(context_token_ids)
+
+ if context_end_idx is None:
+ if self.debug:
+ logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
+ continue
+
+ # Extract only tokens that belong to the value itself (skip context tokens)
+ value_token_ids = token_ids[context_end_idx:]
+
+ # Build prefix mappings using token ID sequences as keys
+ for i in range(len(value_token_ids) + 1):
+ # Current token sequence prefix (empty tuple for start)
+ token_prefix = tuple(value_token_ids[:i])
+
+ if token_prefix not in prefix_to_tokens:
+ prefix_to_tokens[token_prefix] = set()
+
+ if i < len(value_token_ids):
+ # Add next token as allowed for current prefix
+ next_token_id = value_token_ids[i]
+ prefix_to_tokens[token_prefix].add(next_token_id)
+ else:
+ # Complete value should allow newline
+ if self.newline_token:
+ prefix_to_tokens[token_prefix].add(self.newline_token)
+
+ return prefix_to_tokens
+
+ def _build_language_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
+ """
+ Build language prefix to allowed tokens mapping based on ACTUAL tokenization.
+ Similar to keyscale prefix tree but for language codes.
+
+ Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches.
+ """
+ prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}
+
+ context_prefix_for_matching = "language:"
+ context_prefix_for_tokenization = "language: "
+
+ context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False)
+
+ if self.debug:
+ context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids]
+ logger.debug(f"Context for matching 'language:' tokenizes to {context_token_ids} -> {context_tokens_str}")
+
+ for lang in self.valid_languages:
+ full_text = context_prefix_for_tokenization + lang
+ full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)
+
+ context_end_idx = None
+ if len(full_token_ids) >= len(context_token_ids):
+ if full_token_ids[:len(context_token_ids)] == context_token_ids:
+ context_end_idx = len(context_token_ids)
+
+ if context_end_idx is None:
+ if self.debug:
+ logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
+ continue
+
+ lang_token_ids = full_token_ids[context_end_idx:]
+
+ if not lang_token_ids:
+ if self.debug:
+ logger.warning(f"No tokens extracted for language '{lang}', skipping")
+ continue
+
+ for i in range(len(lang_token_ids) + 1):
+ token_prefix = tuple(lang_token_ids[:i])
+
+ if token_prefix not in prefix_to_tokens:
+ prefix_to_tokens[token_prefix] = set()
+
+ if i < len(lang_token_ids):
+ next_token_id = lang_token_ids[i]
+ prefix_to_tokens[token_prefix].add(next_token_id)
+ else:
+ if self.newline_token:
+ prefix_to_tokens[token_prefix].add(self.newline_token)
+
+ if self.debug:
+ logger.debug(f"Built language prefix tree with {len(prefix_to_tokens)} token sequence prefixes")
+ empty_prefix = tuple()
+ if empty_prefix in prefix_to_tokens:
+ first_tokens = prefix_to_tokens[empty_prefix]
+ decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)]
+ logger.debug(f"First tokens allowed for language (empty prefix): {decoded_first}")
+
+ return prefix_to_tokens
+
+ def diagnose_keyscale_prefix_tree(self):
+ """
+ Diagnose the keyscale prefix tree to help debug generation bias.
+ Call this method to print detailed information about allowed tokens at each prefix.
+ """
+ print("=" * 60)
+ print("KEYSCALE PREFIX TREE DIAGNOSIS")
+ print("=" * 60)
+
+ # Check empty prefix (first token)
+ if "" in self.keyscale_prefix_tree:
+ first_tokens = self.keyscale_prefix_tree[""]
+ print(f"\n[Empty prefix] Allowed first tokens ({len(first_tokens)} total):")
+ for t in sorted(first_tokens):
+ decoded = self.tokenizer.decode([t])
+ print(f" Token {t}: {repr(decoded)}")
+ else:
+ print("\nWARNING: Empty prefix not in tree!")
+
+ # Check some common prefixes
+ test_prefixes = ["A", "B", "C", "D", "E", "F", "G"]
+ for prefix in test_prefixes:
+ # Try both with and without potential tokenizer artifacts
+ for test_key in [prefix, prefix + " "]:
+ if test_key in self.keyscale_prefix_tree:
+ tokens = self.keyscale_prefix_tree[test_key]
+ print(f"\n[Prefix {repr(test_key)}] Allowed tokens ({len(tokens)}):")
+ for t in sorted(tokens):
+ decoded = self.tokenizer.decode([t])
+ print(f" Token {t}: {repr(decoded)}")
+
+ # Show some complete keyscales that should be valid
+ print(f"\n[Valid keyscales] Total: {len(self.valid_keyscales)}")
+ sample = sorted(list(self.valid_keyscales))[:10]
+ for ks in sample:
+ print(f" {repr(ks)}")
+
+ print("=" * 60)
+
+
+ def _load_genres_vocab(self):
+ """
+ Load genres vocabulary from file. Supports hot reload by checking file mtime.
+ File format: one genre per line, lines starting with # are comments.
+ """
+ if not os.path.exists(self.genres_vocab_path):
+ if self.debug:
+ logger.debug(f"Genres vocab file not found: {self.genres_vocab_path}")
+ return
+
+ try:
+ mtime = os.path.getmtime(self.genres_vocab_path)
+ if mtime <= self.genres_vocab_mtime:
+ return # File hasn't changed
+
+ with open(self.genres_vocab_path, 'r', encoding='utf-8') as f:
+ genres = []
+ for line in f:
+ line = line.strip()
+ if line and not line.startswith('#'):
+ genres.append(line.lower())
+
+ self.genres_vocab = genres
+ self.genres_vocab_mtime = mtime
+ self._build_genres_trie()
+
+ if self.debug:
+ logger.debug(f"Loaded {len(self.genres_vocab)} genres from {self.genres_vocab_path}")
+ except Exception as e:
+ logger.warning(f"Failed to load genres vocab: {e}")
+
+ def _build_genres_trie(self):
+ """
+ Build a trie (prefix tree) from genres vocabulary for efficient prefix matching.
+ Each node is a dict with:
+ - '_end': True if this node represents a complete genre
+ - other keys: next characters in the trie
+ """
+ self.genres_trie = {}
+
+ for genre in self.genres_vocab:
+ node = self.genres_trie
+ for char in genre:
+ if char not in node:
+ node[char] = {}
+ node = node[char]
+ node['_end'] = True # Mark end of a complete genre
+
+ if self.debug:
+ logger.debug(f"Built genres trie with {len(self.genres_vocab)} entries")
+
+ def _extract_caption_genres(self, caption: str):
+ """
+ Extract genres from the user's caption that match entries in the vocabulary.
+ This creates a smaller trie for faster and more relevant genre generation.
+
+ Strategy (optimized - O(words * max_genre_len) instead of O(vocab_size)):
+ 1. Extract words/phrases from caption
+ 2. For each word, use trie to find all vocab entries that START with this word
+ 3. Build a separate trie from matched genres
+ """
+ if not caption or not self.genres_vocab:
+ return
+
+ caption_lower = caption.lower()
+ matched_genres = set()
+
+ # Extract words from caption (split by common delimiters)
+ import re
+ words = re.split(r'[,\s\-_/\\|]+', caption_lower)
+ words = [w.strip() for w in words if w.strip() and len(w.strip()) >= 2]
+
+ # For each word, find genres in trie that start with this word
+ for word in words:
+ # Find all genres starting with this word using trie traversal
+ node = self._get_genres_trie_node(word)
+ if node is not None:
+ # Collect all complete genres under this node
+ self._collect_complete_genres(node, word, matched_genres)
+
+ # Also check if any word appears as a substring in short genres (< 20 chars)
+ # This is a quick check for common single-word genres
+ genres_set = set(self.genres_vocab)
+ for word in words:
+ if word in genres_set:
+ matched_genres.add(word)
+
+ if not matched_genres:
+ if self.debug:
+ logger.debug(f"No genres matched in caption, using full vocab")
+ return
+
+ # Build a trie from matched genres
+ self.caption_matched_genres = list(matched_genres)
+ self.caption_genres_trie = {}
+
+ for genre in matched_genres:
+ node = self.caption_genres_trie
+ for char in genre:
+ if char not in node:
+ node[char] = {}
+ node = node[char]
+ node['_end'] = True
+
+ if self.debug:
+ logger.debug(f"Matched {len(matched_genres)} genres from caption: {list(matched_genres)[:5]}...")
+
+ def _collect_complete_genres(self, node: Dict, prefix: str, result: set, max_depth: int = 50):
+ """
+ Recursively collect all complete genres under a trie node.
+ Limited depth to avoid too many matches.
+ """
+ if max_depth <= 0:
+ return
+
+ if node.get('_end', False):
+ result.add(prefix)
+
+ # Limit total collected genres to avoid slowdown
+ if len(result) >= 100:
+ return
+
+ for char, child_node in node.items():
+ if char not in ('_end', '_tokens'):
+ self._collect_complete_genres(child_node, prefix + char, result, max_depth - 1)
+
+ def _precompute_char_token_mapping(self):
+ """
+ Precompute mapping from characters to token IDs and token decoded texts.
+ This allows O(1) lookup instead of calling tokenizer.encode()/decode() at runtime.
+
+ Time complexity: O(vocab_size) - runs once during initialization
+
+ Note: Many subword tokenizers (like Qwen) add space prefixes to tokens.
+ We need to handle both the raw first char and the first non-space char.
+ """
+ self._char_to_tokens: Dict[str, set] = {}
+ self._token_to_text: Dict[int, str] = {} # Precomputed decoded text for each token
+
+ # For each token in vocabulary, get its decoded text
+ for token_id in range(self.vocab_size):
+ try:
+ text = self.tokenizer.decode([token_id])
+
+ if not text:
+ continue
+
+ # Store the decoded text (normalized to lowercase)
+ # Keep leading spaces for proper concatenation (e.g., " rock" in "pop rock")
+ # Only rstrip trailing whitespace, unless it's a pure whitespace token
+ text_lower = text.lower()
+ if text_lower.strip(): # Has non-whitespace content
+ normalized_text = text_lower.rstrip()
+ else: # Pure whitespace token
+ normalized_text = " " # Normalize to single space
+ self._token_to_text[token_id] = normalized_text
+
+ # Map first character (including space) to this token
+ first_char = text[0].lower()
+ if first_char not in self._char_to_tokens:
+ self._char_to_tokens[first_char] = set()
+ self._char_to_tokens[first_char].add(token_id)
+
+ # Also map first non-space character to this token
+ # This handles tokenizers that add space prefixes (e.g., " pop" -> maps to 'p')
+ stripped_text = text.lstrip()
+ if stripped_text and stripped_text != text:
+ first_nonspace_char = stripped_text[0].lower()
+ if first_nonspace_char not in self._char_to_tokens:
+ self._char_to_tokens[first_nonspace_char] = set()
+ self._char_to_tokens[first_nonspace_char].add(token_id)
+
+ except Exception:
+ continue
+
+ if self.debug:
+ logger.debug(f"Precomputed char->token mapping for {len(self._char_to_tokens)} unique characters")
+
+ def _try_reload_genres_vocab(self):
+ """Check if genres vocab file has been updated and reload if necessary."""
+ if not os.path.exists(self.genres_vocab_path):
+ return
+
+ try:
+ mtime = os.path.getmtime(self.genres_vocab_path)
+ if mtime > self.genres_vocab_mtime:
+ self._load_genres_vocab()
+ except Exception:
+ pass # Ignore errors during hot reload check
+
+ def _get_genres_trie_node(self, prefix: str) -> Optional[Dict]:
+ """
+ Get the trie node for a given prefix.
+ Returns None if the prefix is not valid (no genres start with this prefix).
+ """
+ node = self.genres_trie
+ for char in prefix.lower():
+ if char not in node:
+ return None
+ node = node[char]
+ return node
+
+ def _is_complete_genre(self, text: str) -> bool:
+ """Check if the given text is a complete genre in the vocabulary."""
+ node = self._get_genres_trie_node(text.strip())
+ return node is not None and node.get('_end', False)
+
+ def _get_trie_node_from_trie(self, trie: Dict, prefix: str) -> Optional[Dict]:
+ """Get a trie node from a specific trie (helper for caption vs full trie)."""
+ node = trie
+ for char in prefix.lower():
+ if char not in node:
+ return None
+ node = node[char]
+ return node
+
+ def _get_allowed_genres_tokens(self) -> List[int]:
+ """
+ Get allowed tokens for genres field based on trie matching.
+
+ The entire genres string (including commas) must match a complete entry in the vocab.
+ For example, if vocab contains "pop, rock, jazz", the generated string must exactly
+ match that entry - we don't treat commas as separators for individual genres.
+
+ Strategy:
+ 1. If caption-matched genres exist, use that smaller trie first (faster + more relevant)
+ 2. If no caption matches or prefix not in caption trie, fallback to full vocab trie
+ 3. Get valid next characters from current trie node
+ 4. For each candidate token, verify the full decoded text forms a valid trie prefix
+ """
+ if not self.genres_vocab:
+ # No vocab loaded, allow all except newline if empty
+ return []
+
+ # Use the full accumulated value (don't split by comma - treat as single entry)
+ accumulated = self.accumulated_value.lower()
+ current_genre_prefix = accumulated.strip()
+
+ # Determine which trie to use: caption-matched (priority) or full vocab (fallback)
+ use_caption_trie = False
+ current_node = None
+
+ # Try caption-matched trie first if available
+ if self.caption_genres_trie:
+ if current_genre_prefix == "":
+ current_node = self.caption_genres_trie
+ use_caption_trie = True
+ else:
+ current_node = self._get_trie_node_from_trie(self.caption_genres_trie, current_genre_prefix)
+ if current_node is not None:
+ use_caption_trie = True
+
+ # Fallback to full vocab trie
+ if current_node is None:
+ if current_genre_prefix == "":
+ current_node = self.genres_trie
+ else:
+ current_node = self._get_genres_trie_node(current_genre_prefix)
+
+ if current_node is None:
+ # Invalid prefix, force newline to end
+ if self.newline_token:
+ return [self.newline_token]
+ return []
+
+ # Get valid next characters from trie node
+ valid_next_chars = set(k for k in current_node.keys() if k not in ('_end', '_tokens'))
+
+ # If current value is a complete genre, allow newline to end
+ is_complete = current_node.get('_end', False)
+
+ if not valid_next_chars:
+ # No more characters to match, only allow newline if complete
+ allowed = set()
+ if is_complete and self.newline_token:
+ allowed.add(self.newline_token)
+ return list(allowed)
+
+ # Collect candidate tokens based on first character
+ candidate_tokens = set()
+ for char in valid_next_chars:
+ if char in self._char_to_tokens:
+ candidate_tokens.update(self._char_to_tokens[char])
+
+ # Select the appropriate trie for validation
+ active_trie = self.caption_genres_trie if use_caption_trie else self.genres_trie
+
+ # Validate each candidate token: check if prefix + decoded_token is a valid trie prefix
+ allowed = set()
+ for token_id in candidate_tokens:
+ # Use precomputed decoded text (already normalized)
+ decoded_normalized = self._token_to_text.get(token_id, "")
+
+ if not decoded_normalized or not decoded_normalized.strip():
+ # Token decodes to empty or only whitespace - allow if space/comma is a valid next char
+ if ' ' in valid_next_chars or ',' in valid_next_chars:
+ allowed.add(token_id)
+ continue
+
+ # Build new prefix by appending decoded token
+ # Handle space-prefixed tokens (e.g., " rock" from "pop rock")
+ if decoded_normalized.startswith(' ') or decoded_normalized.startswith(','):
+ # Token has leading space/comma - append directly
+ new_prefix = current_genre_prefix + decoded_normalized
+ else:
+ new_prefix = current_genre_prefix + decoded_normalized
+
+ # Check if new_prefix is a valid prefix in the active trie
+ new_node = self._get_trie_node_from_trie(active_trie, new_prefix)
+ if new_node is not None:
+ allowed.add(token_id)
+
+ # If current value is a complete genre, also allow newline
+ if is_complete and self.newline_token:
+ allowed.add(self.newline_token)
+
+ return list(allowed)
+
+ def reset(self):
+ """Reset the processor state for a new generation."""
+ self.state = FSMState.THINK_TAG
+ self.position_in_state = 0
+ self.accumulated_value = "" # Legacy, kept for compatibility
+ self.accumulated_token_ids = [] # Reset token ID sequence
+ self.codes_count = 0 # Reset codes counter
+ self.user_field_token_queue = [] # Reset user field token queue
+ self.current_user_field = None # Reset current user field
+ self.caption_after_newline = False # Reset caption newline tracking
+ self.caption_token_count = 0 # Reset caption token count
+ self.caption_ending = False # Reset caption ending tracking
+ self.pending_field_name = "" # Reset pending field name
+
+ def set_target_duration(self, duration: Optional[float]):
+ """
+ Set the target duration for codes generation.
+
+ Args:
+ duration: Target duration in seconds. If None, no duration constraint is applied.
+ 5 codes = 1 second, so target_codes = duration * 5.
+ """
+ self.target_duration = duration
+ if duration is not None and duration > 0:
+ self.target_codes = int(duration * 5)
+ if self.debug:
+ logger.debug(f"Set target duration: {duration}s -> {self.target_codes} codes")
+ else:
+ self.target_codes = None
+ if self.debug:
+ logger.debug("Target duration cleared, no duration constraint")
+
+ def _get_allowed_tokens_for_fixed_string(self, fixed_str: str) -> List[int]:
+ """
+ Get the token IDs that can continue the fixed string from current position.
+ Returns list of allowed token IDs.
+
+ Strategy: Find the longest prefix that encodes to a single token, and return that token.
+ This ensures we generate by tokens, not character-by-character.
+ """
+ remaining = fixed_str[self.position_in_state:]
+ if not remaining:
+ return []
+
+ if self.debug:
+ logger.debug(f"_get_allowed_tokens_for_fixed_string: fixed_str={repr(fixed_str)}, position_in_state={self.position_in_state}, remaining={repr(remaining)}")
+
+ # Try encoding progressively longer prefixes, from longest to shortest
+ # We want to find the longest prefix that encodes to a single token
+ best_token = None
+ best_prefix_len = 0
+
+ # First pass: find the longest prefix that encodes to exactly one token
+ for end in range(len(remaining), 0, -1): # Start from longest prefix
+ prefix = remaining[:end]
+ tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
+ if tokens and len(tokens) == 1:
+ # Found a prefix that encodes to a single token
+ # Use this one (longest match)
+ best_token = tokens[0]
+ best_prefix_len = end
+ if self.debug:
+ logger.debug(f"Found single-token match: prefix={repr(prefix)}, token_id={best_token}, token_text={repr(self.tokenizer.decode([best_token]))}")
+ break
+
+ # If we found a single-token match, return it (this is the preferred case)
+ if best_token is not None:
+ return [best_token]
+
+ # Fallback: if no single-token match found, collect all possible first tokens
+ # This handles edge cases where the string might need multiple tokens
+ # But we still want to prefer longer matches
+ # IMPORTANT: Only consider tokens that actually match the beginning of remaining string
+ # Decode each candidate token and verify it matches the prefix
+ allowed_tokens = {}
+ for end in range(1, min(len(remaining) + 1, 20)): # Limit search to avoid too many iterations
+ prefix = remaining[:end]
+ tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
+ if tokens:
+ first_token = tokens[0]
+ # Verify: decode the token and check it matches the prefix start
+ decoded_token = self.tokenizer.decode([first_token])
+ # Normalize both for comparison (strip and lower)
+ normalized_prefix = prefix.lstrip().lower()
+ normalized_decoded = decoded_token.lstrip().lower()
+
+ # Check if decoded token matches the prefix start (allowing for space prefixes)
+ if normalized_decoded.startswith(normalized_prefix) or normalized_prefix.startswith(normalized_decoded):
+ # Store the longest prefix length for each token
+ if first_token not in allowed_tokens or end > allowed_tokens[first_token]:
+ allowed_tokens[first_token] = end
+
+ # Return tokens sorted by prefix length (longest first)
+ # This ensures we prefer longer matches
+ sorted_tokens = sorted(allowed_tokens.items(), key=lambda x: x[1], reverse=True)
+ result = [token for token, _ in sorted_tokens] if sorted_tokens else []
+
+ if self.debug:
+ logger.debug(f"Fallback: returning {len(result)} tokens: {[(t, repr(self.tokenizer.decode([t]))) for t in result[:5]]}")
+ if result:
+ logger.debug(f"Fixed string: {repr(fixed_str)}, position: {self.position_in_state}, remaining: {repr(remaining)}")
+
+ return result
+
+ def _get_allowed_digit_tokens(self, min_val: int, max_val: int) -> List[int]:
+ """
+ Get allowed digit tokens based on accumulated value and range constraints.
+ Uses early-blocking to prevent out-of-range values.
+ """
+ if not self.accumulated_value:
+ # First digit: determine valid starting digits
+ allowed_digits = set()
+ for v in range(min_val, max_val + 1):
+ allowed_digits.add(int(str(v)[0]))
+ return [self.digit_tokens[d] for d in allowed_digits if d in self.digit_tokens]
+
+ current = int(self.accumulated_value)
+ allowed = []
+
+ for d in range(10):
+ new_value = int(self.accumulated_value + str(d))
+ # Check if this digit could lead to a valid final value
+ # A digit is valid if:
+ # 1. new_value <= max_val (not already exceeded)
+ # 2. new_value could potentially reach >= min_val
+ # (i.e., new_value * 10^k >= min_val for some k >= 0)
+
+ if new_value > max_val:
+ continue # Already exceeded max
+
+ # Check if we can still reach min_val
+ # If new_value is already >= min_val, it's valid
+ # If new_value < min_val, we need more digits, but new_value * 10 must not exceed max
+ if new_value >= min_val:
+ allowed.append(d)
+ elif new_value * 10 <= max_val:
+ # Can add more digits
+ allowed.append(d)
+
+ return [self.digit_tokens[d] for d in allowed if d in self.digit_tokens]
+
+ def _get_allowed_numeric_tokens(self, prefix_tree: Dict[Tuple[int, ...], Set[int]]) -> List[int]:
+ """
+ Get allowed tokens for numeric field using the precomputed prefix tree.
+
+ IMPORTANT: Uses token ID sequence as key (not string) to avoid tokenization mismatches.
+
+ Args:
+ prefix_tree: Precomputed prefix tree mapping token ID sequence -> set of allowed token IDs
+
+ Returns:
+ List of allowed token IDs for current accumulated_token_ids
+ """
+ token_prefix = tuple(self.accumulated_token_ids)
+
+ if token_prefix in prefix_tree:
+ return list(prefix_tree[token_prefix])
+
+ # No valid continuation found - return empty list
+ # The caller will handle this by forcing newline to end the field
+ return []
+
+ def _should_end_numeric_field(self, logits: torch.Tensor, min_val: int, max_val: int) -> bool:
+ """
+ Determine if we should end the current numeric field.
+ Returns True if P(newline) > P(any valid digit) AND current value is valid.
+ """
+ if not self.accumulated_value:
+ return False
+
+ current = int(self.accumulated_value)
+ if current < min_val or current > max_val:
+ return False # Can't end yet, value not in range
+
+ # Get probabilities
+ probs = torch.softmax(logits, dim=-1)
+
+ newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0
+
+ # Get max probability among valid digit tokens
+ allowed_digits = self._get_allowed_digit_tokens(min_val, max_val)
+ if not allowed_digits:
+ return True # No more digits possible, must end
+
+ max_digit_prob = max(probs[0, t].item() for t in allowed_digits)
+
+ if self.debug:
+ logger.debug(f"Numeric field decision: newline_prob={newline_prob:.4f}, max_digit_prob={max_digit_prob:.4f}")
+
+ return newline_prob > max_digit_prob
+
+
+ def _should_end_text_field(self, logits: torch.Tensor) -> bool:
+ """
+ Determine if we should end a text field (genres).
+ Returns True if P(newline) > P(any other token) AND we have some content.
+ """
+ if not self.accumulated_value.strip():
+ return False # Need at least some content
+
+ probs = torch.softmax(logits, dim=-1)
+ newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0
+
+ # Get max probability among non-newline tokens
+ masked_probs = probs.clone()
+ if self.newline_token:
+ masked_probs[0, self.newline_token] = 0
+ max_other_prob = masked_probs[0].max().item()
+
+ return newline_prob > max_other_prob
+
+ def _get_allowed_keyscale_tokens(self) -> List[int]:
+ """
+ Get allowed tokens for keyscale field using the precomputed prefix tree.
+ Uses token ID sequence as key (not string) to avoid tokenization mismatches.
+ """
+ # Use token ID sequence as key
+ token_prefix = tuple(self.accumulated_token_ids)
+
+ if token_prefix in self.keyscale_prefix_tree:
+ return list(self.keyscale_prefix_tree[token_prefix])
+
+ # Fallback: if we somehow drifted off (shouldn't happen with constrained decoding),
+ # return empty to force newline logic or stop.
+ return []
+
+ def _is_keyscale_complete(self) -> bool:
+ """
+ Check if keyscale value is complete and valid.
+ Uses token ID sequence to check if current prefix allows newline.
+ """
+ token_prefix = tuple(self.accumulated_token_ids)
+ # If current token sequence prefix is in tree and allows newline, it's complete
+ if token_prefix in self.keyscale_prefix_tree:
+ return self.newline_token in self.keyscale_prefix_tree[token_prefix]
+ return False
+
+ def _get_allowed_language_tokens(self) -> List[int]:
+ """
+ Get allowed tokens for language field using the precomputed prefix tree.
+ Uses token ID sequence as key (not string) to avoid tokenization mismatches.
+ Similar to keyscale.
+ """
+ token_prefix = tuple(self.accumulated_token_ids)
+
+ if token_prefix in self.language_prefix_tree:
+ return list(self.language_prefix_tree[token_prefix])
+
+ # Fallback: no valid continuation found
+ return []
+
+ def _get_allowed_timesig_tokens(self) -> List[int]:
+ """
+ Get allowed tokens for timesignature field using the precomputed prefix tree.
+ Uses token ID sequence as key (not string) to avoid tokenization mismatches.
+ """
+ token_prefix = tuple(self.accumulated_token_ids)
+
+ if token_prefix in self.timesig_prefix_tree:
+ return list(self.timesig_prefix_tree[token_prefix])
+
+ # No valid continuation found - return empty list
+ # The caller will handle this by forcing newline to end the field
+ return []
+
+ def __call__(
+ self,
+ input_ids: torch.LongTensor,
+ scores: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ Apply constrained decoding by modifying logits.
+
+ Args:
+ input_ids: [batch_size, seq_len] input token IDs
+ scores: [batch_size, vocab_size] logits for next token
+
+ Returns:
+ Modified scores with invalid tokens masked to -inf and temperature scaling applied
+ """
+ if not self.enabled:
+ return self._apply_temperature_scaling(scores)
+
+ if self.state == FSMState.COMPLETED:
+ # In understanding phase, block audio codes during lyrics generation (COMPLETED state)
+ if self.generation_phase == "understand" and self.audio_code_mask is not None:
+ # Move mask to same device/dtype as scores if needed
+ if self.audio_code_mask.device != scores.device or self.audio_code_mask.dtype != scores.dtype:
+ self.audio_code_mask = self.audio_code_mask.to(device=scores.device, dtype=scores.dtype)
+ scores = scores + self.audio_code_mask
+ return self._apply_temperature_scaling(scores)
+
+ # For codes phase, detect if input already contains and skip to CODES_GENERATION
+ if self.generation_phase == "codes" and self.state == FSMState.THINK_TAG:
+ # Check if input contains token sequence
+ if self._input_contains_think_end_tag(input_ids):
+ # Skip metadata generation, go directly to codes generation
+ self.state = FSMState.CODES_GENERATION
+ self.codes_count = 0
+ if self.debug:
+ logger.debug("Codes phase: detected in input, skipping to CODES_GENERATION")
+
+ if self.state == FSMState.CODES_GENERATION:
+ # Block all non-audio-code tokens (only allow audio codes and EOS)
+ if self.non_audio_code_mask is not None:
+ # Move mask to same device/dtype as scores if needed
+ if self.non_audio_code_mask.device != scores.device or self.non_audio_code_mask.dtype != scores.dtype:
+ self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
+ scores = scores + self.non_audio_code_mask
+
+ # Apply duration constraint in codes generation phase
+ if self.target_codes is not None and self.eos_token_id is not None:
+ if self.codes_count < self.target_codes:
+ # Block EOS token until target codes count is reached
+ scores[:, self.eos_token_id] = float('-inf')
+ if self.debug:
+ logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
+ else:
+ # Force EOS token when target codes count is reached - inplace
+ eos_scores = scores[:, self.eos_token_id].clone()
+ scores.fill_(float('-inf'))
+ scores[:, self.eos_token_id] = eos_scores
+ if self.debug:
+ logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
+ return self._apply_temperature_scaling(scores)
+
+ batch_size = scores.shape[0]
+
+ # Process each sequence in batch
+ for b in range(batch_size):
+ result = self._process_single_sequence(input_ids[b], scores[b:b+1])
+ scores[b] = result[0] # result is [1, vocab_size], need [vocab_size]
+
+ # Apply temperature scaling after constraint masking
+ return self._apply_temperature_scaling(scores)
+
+ def _input_contains_think_end_tag(self, input_ids: torch.LongTensor) -> bool:
+ """
+ Check if input contains the closing tag.
+
+ Args:
+ input_ids: [batch_size, seq_len] input token IDs
+
+ Returns:
+ True if is found in the input (any sequence in batch)
+ """
+ # Tokenize to get its token sequence
+ think_end_tokens = self.tokenizer.encode("", add_special_tokens=False)
+ if not think_end_tokens:
+ return False
+
+ # Check each sequence in batch
+ for b in range(input_ids.shape[0]):
+ seq = input_ids[b].tolist()
+ # Search for the token sequence in the input
+ for i in range(len(seq) - len(think_end_tokens) + 1):
+ if seq[i:i+len(think_end_tokens)] == think_end_tokens:
+ return True
+
+ return False
+
+ def _apply_temperature_scaling(self, scores: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ Apply temperature scaling based on current generation phase.
+
+ Temperature scaling: logits = logits / temperature
+ - Lower temperature (< 1.0) makes distribution sharper (more deterministic)
+ - Higher temperature (> 1.0) makes distribution flatter (more diverse)
+
+ Args:
+ scores: [batch_size, vocab_size] logits
+
+ Returns:
+ Temperature-scaled logits
+ """
+ # Determine which temperature to use based on current state
+ if self.state == FSMState.CODES_GENERATION or self.state == FSMState.COMPLETED:
+ temperature = self.codes_temperature
+ else:
+ temperature = self.metadata_temperature
+
+ # If no temperature is set for this phase, return scores unchanged
+ if temperature is None:
+ return scores
+
+ # Avoid division by zero
+ if temperature <= 0:
+ temperature = 1e-6
+
+ # Apply temperature scaling
+ return scores / temperature
+
+ def _get_user_provided_field_tokens(self, field_name: str) -> Optional[List[int]]:
+ """
+ Get token sequence for a user-provided field (field_name + value + newline).
+ Uses the same tokenization logic as prefix tree building.
+
+ Args:
+ field_name: Field name ("bpm", "caption", "duration", "keyscale", "language", "timesignature")
+
+ Returns:
+ List of token IDs for the complete field, or None if field is not provided
+ """
+ value = self.user_provided_metadata.get(field_name)
+ if value is None:
+ return None
+
+ # Build full field string with space (matching prefix tree tokenization)
+ field_to_prefix = {
+ "bpm": "bpm: ",
+ "caption": "caption: ",
+ "duration": "duration: ",
+ "keyscale": "keyscale: ",
+ "language": "language: ",
+ "timesignature": "timesignature: ",
+ "genres": "genres: ",
+ }
+ prefix = field_to_prefix[field_name]
+ full_text = f"{prefix}{value}\n"
+
+ # Tokenize the full field
+ tokens = self.tokenizer.encode(full_text, add_special_tokens=False)
+
+ # Extract only the field tokens (skip the prefix tokens that match state machine output)
+ # The state machine generates "field_name:" (no space), so we need to match that
+ prefix_for_matching = field_name + ":"
+ prefix_tokens = self.tokenizer.encode(prefix_for_matching, add_special_tokens=False)
+
+ # Find where prefix ends in full tokens
+ if len(tokens) >= len(prefix_tokens) and tokens[:len(prefix_tokens)] == prefix_tokens:
+ # Return tokens after prefix (field value + newline)
+ return tokens[len(prefix_tokens):]
+ else:
+ # Fallback: return all tokens (shouldn't happen if tokenization is consistent)
+ if self.debug:
+ logger.warning(f"Could not match prefix tokens for field {field_name}, using all tokens")
+ return tokens
+
+ def _process_single_sequence(
+ self,
+ input_ids: torch.LongTensor,
+ scores: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """Process a single sequence and return modified scores (inplace when possible)."""
+
+ # Check if we have tokens in queue for user-provided field
+ # If so, inject the next token directly
+ if self.user_field_token_queue:
+ next_token = self.user_field_token_queue[0]
+ self._apply_whitelist_inplace(scores, [next_token])
+ return scores
+
+ if self.state in self.fixed_strings:
+ # Fixed string state: force specific tokens
+ fixed_str = self.fixed_strings[self.state]
+ allowed = self._get_allowed_tokens_for_fixed_string(fixed_str)
+
+ if allowed:
+ # Check if we should stop at reasoning (after tag)
+ # This happens when we're about to complete the tag
+ if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
+ # Check if the next token would complete the fixed string
+ remaining_chars = len(fixed_str) - self.position_in_state
+ # If remaining is small (<= 10 chars, which is typically 1-2 tokens), force EOS
+ if remaining_chars <= 10:
+ # Force EOS token to stop generation
+ if self.eos_token_id is not None:
+ self._apply_whitelist_inplace(scores, [self.eos_token_id])
+ if self.debug:
+ logger.debug(f"stop_at_reasoning=True: forcing EOS near end of tag (remaining: {remaining_chars} chars)")
+ return scores
+
+ # Apply whitelist constraint inplace
+ self._apply_whitelist_inplace(scores, allowed)
+ else:
+ # Position exceeds string, move to next state
+ # If stop_at_reasoning is True and we're transitioning from THINK_END_TAG,
+ # force EOS before transitioning
+ if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
+ # Force EOS token to stop generation
+ if self.eos_token_id is not None:
+ self._apply_whitelist_inplace(scores, [self.eos_token_id])
+ if self.debug:
+ logger.debug(f"stop_at_reasoning=True: forcing EOS after completing tag")
+ return scores
+
+ old_state = self.state
+ self._transition_to_next_state()
+ # Avoid infinite recursion: if we're still in a fixed_strings state, just return scores
+ if self.state in self.fixed_strings:
+ # This shouldn't happen, but if it does, just return scores to avoid recursion
+ if self.debug:
+ logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion")
+ return scores
+ # For recursion, reset scores to zero (no constraints from previous state)
+ scores.zero_()
+ return self._process_single_sequence(input_ids, scores)
+
+ elif self.state == FSMState.BPM_VALUE:
+ # Check if field is user-provided and we haven't started injecting yet
+ if self.user_provided_metadata["bpm"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
+ # Initialize token queue with field value tokens (value + newline)
+ value = self.user_provided_metadata["bpm"]
+ # Tokenize " value\n" (space + value + newline) to match actual tokenization
+ value_text = f" {value}\n"
+ value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+ if value_tokens:
+ self.user_field_token_queue = value_tokens
+ self.current_user_field = "bpm"
+ # Inject first token
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
+ return scores
+
+ # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "120")
+ allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree)
+
+ # Also allow newline if current token sequence prefix allows it
+ token_prefix = tuple(self.accumulated_token_ids)
+ if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]:
+ allowed = allowed + [self.newline_token]
+
+ self._apply_whitelist_inplace(scores, allowed)
+
+ elif self.state == FSMState.CAPTION_VALUE:
+ # Caption field generation with YAML format support:
+ # - Allow newlines and spaces (YAML multi-line formatting)
+ # - Block audio codes and backticks
+ # - Max 512 tokens
+ # - Transition when model wants to generate next field (non-indented line)
+
+ # Check if field is user-provided and we haven't started injecting yet
+ if self.user_provided_metadata["caption"] is not None and not self.user_field_token_queue and not self.accumulated_value:
+ # Initialize token queue with field value tokens (value + newline)
+ value = self.user_provided_metadata["caption"]
+ value_text = f" {value}\n"
+ value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+ if value_tokens:
+ self.user_field_token_queue = value_tokens
+ self.current_user_field = "caption"
+ # Inject first token
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
+ return scores
+
+ # Check if we should transition after a newline (non-indented line = new field)
+ if self.caption_after_newline:
+ # Get top token from current scores
+ top_token_id = torch.argmax(scores[0]).item()
+ top_token_text = self.tokenizer.decode([top_token_id])
+
+ # If top token does NOT start with space/tab, it's a new field (like "duration:")
+ if len(top_token_text) > 0 and top_token_text[0] not in ' \t':
+ # Caption is ending - LM is generating next field name
+ # Instead of forcing state transition to DURATION_NAME (which would regenerate the field name),
+ # we enter a "caption_ending" mode where we allow free generation until we detect the field value
+ self.caption_after_newline = False
+ self.caption_ending = True # Start tracking field name
+ self.pending_field_name = "" # Reset pending field name
+ # Allow free generation (no constraints) so LM can generate field name naturally
+ return scores
+ else:
+ # It's indentation, continue caption (don't transition!)
+ self.caption_after_newline = False
+ # Continue normal caption generation
+ # Fall through to caption constraints below
+
+ # If caption is ending (LM generating next field name), allow free generation
+ # and track the field name until we see colon
+ if self.caption_ending:
+ # Allow any token (free generation)
+ # The field name detection will happen in update_state()
+ return scores
+
+ # Block backticks (code blocks) - inplace
+ if self.backtick_token is not None:
+ scores[0, self.backtick_token] = float('-inf')
+
+ # Block ALL audio code tokens (critical - these should never appear in caption)
+ # Use precomputed mask for O(1) performance instead of O(n) loop
+ if self.audio_code_mask is not None:
+ # Move mask to same device/dtype as scores if needed
+ if self.audio_code_mask.device != scores.device or self.audio_code_mask.dtype != scores.dtype:
+ self.audio_code_mask = self.audio_code_mask.to(device=scores.device, dtype=scores.dtype)
+ scores = scores + self.audio_code_mask
+
+ # Enforce 512 token limit for caption
+ if self.caption_token_count >= 512:
+ # Force end by only allowing newline
+ if self.newline_token is not None:
+ self._apply_whitelist_inplace(scores, [self.newline_token])
+ return scores
+
+ # Allow natural generation (with blocked audio codes and backticks)
+ return scores
+
+ elif self.state == FSMState.DURATION_VALUE:
+ # Check if field is user-provided and we haven't started injecting yet
+ if self.user_provided_metadata["duration"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
+ # Initialize token queue with field value tokens (value + newline)
+ value = self.user_provided_metadata["duration"]
+ value_text = f" {value}\n"
+ value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+ if value_tokens:
+ self.user_field_token_queue = value_tokens
+ self.current_user_field = "duration"
+ # Inject first token
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
+ return scores
+
+ # If target_duration is set, force generate that exact value
+ if self.target_duration is not None:
+ target_str = str(int(self.target_duration))
+ current_pos = len(self.accumulated_value)
+
+ if current_pos < len(target_str):
+ # Force the next digit
+ next_digit = int(target_str[current_pos])
+ if next_digit in self.digit_tokens:
+ self._apply_whitelist_inplace(scores, [self.digit_tokens[next_digit]])
+ else:
+ # All digits generated, force newline
+ if self.newline_token:
+ self._apply_whitelist_inplace(scores, [self.newline_token])
+ else:
+ # Normal duration generation with range constraint
+ # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "60", "120")
+ allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree)
+
+ # Also allow newline if current token sequence prefix allows it
+ token_prefix = tuple(self.accumulated_token_ids)
+ if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]:
+ allowed = allowed + [self.newline_token]
+
+ self._apply_whitelist_inplace(scores, allowed)
+
+ elif self.state == FSMState.GENRES_VALUE:
+ # Check if field is user-provided and we haven't started injecting yet
+ if self.user_provided_metadata["genres"] is not None and not self.user_field_token_queue and not self.accumulated_value:
+ # Initialize token queue with field value tokens (value + newline)
+ value = self.user_provided_metadata["genres"]
+ value_text = f" {value}\n"
+ value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+ if value_tokens:
+ self.user_field_token_queue = value_tokens
+ self.current_user_field = "genres"
+ # Inject first token
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
+ return scores
+
+ # Try to hot-reload genres vocab if file has changed
+ self._try_reload_genres_vocab()
+
+ # Get allowed tokens based on genres vocabulary
+ allowed = self._get_allowed_genres_tokens()
+
+ if allowed:
+ # Use vocabulary-constrained decoding
+ self._apply_whitelist_inplace(scores, allowed)
+ elif self.genres_vocab:
+ # Vocab is loaded but no valid continuation found
+ # Force newline to end the field
+ if self.newline_token:
+ if self.debug:
+ logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
+ self._apply_whitelist_inplace(scores, [self.newline_token])
+ else:
+ # Fallback: no vocab loaded, use probability-based ending
+ if self._should_end_text_field(scores):
+ if self.newline_token:
+ self._apply_whitelist_inplace(scores, [self.newline_token])
+ self._transition_to_next_state()
+ else:
+ # Allow any token except newline if we don't have content yet
+ if not self.accumulated_value.strip():
+ if self.newline_token:
+ scores[0, self.newline_token] = float('-inf')
+ # Otherwise, don't constrain (fallback behavior)
+
+ elif self.state == FSMState.KEYSCALE_VALUE:
+ # Check if field is user-provided and we haven't started injecting yet
+ if self.user_provided_metadata["keyscale"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
+ # Initialize token queue with field value tokens (value + newline)
+ value = self.user_provided_metadata["keyscale"]
+ value_text = f" {value}\n"
+ value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+ if value_tokens:
+ self.user_field_token_queue = value_tokens
+ self.current_user_field = "keyscale"
+ # Inject first token
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
+ return scores
+
+ # Check if current token sequence is complete (allows newline)
+ token_prefix = tuple(self.accumulated_token_ids)
+ if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]:
+ # Complete keyscale, allow newline
+ if self.newline_token:
+ self._apply_whitelist_inplace(scores, [self.newline_token])
+ else:
+ # Not complete, allow valid continuation tokens
+ allowed = self._get_allowed_keyscale_tokens()
+ if allowed:
+ self._apply_whitelist_inplace(scores, allowed)
+ else:
+ # No valid tokens found - force newline to end field
+ # This handles edge cases where keyscale format is unexpected
+ if self.newline_token:
+ self._apply_whitelist_inplace(scores, [self.newline_token])
+
+ elif self.state == FSMState.LANGUAGE_VALUE:
+ # Language field: Use top-1 probability language (greedy selection)
+ # Unlike other fields, we don't use prefix tree sampling.
+ # Instead, we select the highest probability language at the start,
+ # then force generate the rest of that language code.
+
+ # Check if field is user-provided and we haven't started injecting yet
+ if self.user_provided_metadata["language"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
+ # Initialize token queue with field value tokens (value + newline)
+ value = self.user_provided_metadata["language"]
+ value_text = f" {value}\n"
+ value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+ if value_tokens:
+ self.user_field_token_queue = value_tokens
+ self.current_user_field = "language"
+ # Inject first token
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
+ return scores
+
+ # If we haven't started generating language yet (empty accumulated_token_ids),
+ # select the top-1 probability token from all valid first tokens
+ if not self.accumulated_token_ids:
+ # Get all possible first tokens for all languages
+ empty_prefix = tuple()
+ if empty_prefix in self.language_prefix_tree:
+ candidate_tokens = list(self.language_prefix_tree[empty_prefix])
+
+ if candidate_tokens:
+ # Find the token with highest probability (top-1) among candidates
+ # Use tensor indexing to get scores of candidate tokens directly
+ candidate_indices = torch.tensor(candidate_tokens, device=scores.device, dtype=torch.long)
+ candidate_scores = scores[0, candidate_indices]
+
+ # Get the highest probability token among candidates
+ best_idx = torch.argmax(candidate_scores).item()
+ top_token_id = candidate_tokens[best_idx]
+
+ # Only allow this top-1 token, block all others
+ self._apply_whitelist_inplace(scores, [top_token_id])
+
+ if self.debug:
+ top_token_text = self.tokenizer.decode([top_token_id])
+ logger.debug(f"Language field: selected top-1 token {top_token_id} ({repr(top_token_text)}) from {len(candidate_tokens)} candidates")
+ else:
+ # No valid first tokens found - force newline
+ if self.newline_token:
+ self._apply_whitelist_inplace(scores, [self.newline_token])
+ else:
+ # Empty prefix not in tree - force newline
+ if self.newline_token:
+ self._apply_whitelist_inplace(scores, [self.newline_token])
+ else:
+ # We've started generating a language, continue with prefix tree constraints
+ # Check if current token sequence is complete (allows newline)
+ token_prefix = tuple(self.accumulated_token_ids)
+ if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
+ # Complete language, allow newline
+ if self.newline_token:
+ self._apply_whitelist_inplace(scores, [self.newline_token])
+ else:
+ # Not complete, allow valid continuation tokens
+ allowed = self._get_allowed_language_tokens()
+ if allowed:
+ self._apply_whitelist_inplace(scores, allowed)
+ else:
+ # No valid tokens found - force newline to end field
+ if self.newline_token:
+ self._apply_whitelist_inplace(scores, [self.newline_token])
+
+ elif self.state == FSMState.TIMESIG_VALUE:
+ # Check if field is user-provided and we haven't started injecting yet
+ if self.user_provided_metadata["timesignature"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
+ # Initialize token queue with field value tokens (value + newline)
+ value = self.user_provided_metadata["timesignature"]
+ value_text = f" {value}\n"
+ value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+ if value_tokens:
+ self.user_field_token_queue = value_tokens
+ self.current_user_field = "timesignature"
+ # Inject first token
+ self._apply_whitelist_inplace(scores, [value_tokens[0]])
+ return scores
+
+ # Check if current token sequence is complete (allows newline)
+ token_prefix = tuple(self.accumulated_token_ids)
+ if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]:
+ # Complete value, allow newline
+ if self.newline_token:
+ self._apply_whitelist_inplace(scores, [self.newline_token])
+ else:
+ # Not complete, allow valid continuation tokens
+ allowed = self._get_allowed_timesig_tokens()
+ self._apply_whitelist_inplace(scores, allowed)
+
+ return scores
+
+ def _transition_to_next_state(self):
+ """Transition to the next FSM state."""
+ if self.state in self.next_state:
+ old_state = self.state
+ next_state = self.next_state[self.state]
+
+ # Handle different cases at THINK_END_TAG based on generation phase
+ # NOTE: Do NOT override next_state here when stop_at_reasoning=True
+ # because we need to transition to the tag state first to generate ,
+ # then handle stop_at_reasoning in update_state() AFTER the tag is complete
+ if old_state == FSMState.THINK_END_TAG:
+ if self.generation_phase == "understand":
+ # Understanding mode: allow free-form lyrics after metadata
+ # No more constrained decoding after
+ next_state = FSMState.COMPLETED
+ if self.debug:
+ logger.debug(f"generation_phase='understand': allowing free-form lyrics after ")
+ # else: default to CODES_GENERATION (for "codes" phase) or respect stop_at_reasoning flag
+
+ self.state = next_state
+ self.position_in_state = 0
+ self.accumulated_value = "" # Legacy, kept for compatibility
+ self.accumulated_token_ids = [] # Reset token ID sequence for new field
+ self.caption_after_newline = False # Reset caption newline tracking
+ self.caption_token_count = 0 # Reset caption token count
+ self.caption_ending = False # Reset caption ending tracking
+ self.pending_field_name = "" # Reset pending field name
+ if self.debug:
+ logger.debug(f"FSM transition: {old_state.name} -> {self.state.name}")
+
+ def update_state(self, generated_token_id: int):
+ """
+ Update internal state after a token has been generated.
+ This should be called after each token generation.
+
+ Args:
+ generated_token_id: The token ID that was just generated
+ """
+ if not self.enabled:
+ return
+
+ if self.state == FSMState.COMPLETED:
+ return
+
+ if self.state == FSMState.CODES_GENERATION:
+ # Count generated codes for duration constraint
+ self.codes_count += 1
+ if self.debug and self.target_codes is not None:
+ logger.debug(f"Codes count: {self.codes_count}/{self.target_codes}")
+ return
+
+ # Handle user-provided field token injection
+ if self.user_field_token_queue:
+ # Verify the generated token matches the expected token from queue
+ expected_token = self.user_field_token_queue[0]
+ if generated_token_id != expected_token:
+ if self.debug:
+ logger.warning(f"Expected token {expected_token} but got {generated_token_id} for user-provided field {self.current_user_field}")
+
+ # Remove consumed token from queue
+ self.user_field_token_queue.pop(0)
+
+ # If queue is empty, field injection is complete, transition to next state
+ if not self.user_field_token_queue:
+ if self.debug:
+ logger.debug(f"Completed injection of user-provided field: {self.current_user_field}")
+ field_name = self.current_user_field
+ self.current_user_field = None
+
+ # Transition to next state (skip VALUE state since we already injected everything)
+ # The next state should be determined by _get_next_field_state
+ next_state = self._get_next_field_state(field_name)
+ if next_state:
+ old_state = self.state
+ self.state = next_state
+ self.position_in_state = 0
+ self.accumulated_value = ""
+ self.accumulated_token_ids = []
+ if self.debug:
+ logger.debug(f"FSM transition (after user field injection): {old_state.name} -> {self.state.name}")
+ else:
+ # All fields done, go to THINK_END_TAG
+ self._transition_to_next_state()
+ return
+
+ token_str = self.tokenizer.decode([generated_token_id])
+
+ if self.debug:
+ logger.debug(f"Generated token: {repr(token_str)} (id={generated_token_id}), state={self.state.name}")
+
+ if self.state in self.fixed_strings:
+ # Update position in fixed string
+ fixed_str = self.fixed_strings[self.state]
+ self.position_in_state += len(token_str)
+
+ # Check if we've completed the fixed string
+ if self.position_in_state >= len(fixed_str):
+ # Special handling for THINK_END_TAG with stop_at_reasoning
+ if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
+ # tag is complete, now we can stop generation
+ # Force transition to COMPLETED instead of CODES_GENERATION
+ old_state = self.state
+ self.state = FSMState.COMPLETED
+ self.position_in_state = 0
+ self.accumulated_value = ""
+ self.accumulated_token_ids = []
+ if self.debug:
+ logger.debug(f"FSM transition (stop_at_reasoning): {old_state.name} -> {self.state.name}")
+ else:
+ self._transition_to_next_state()
+
+ elif self.state in [FSMState.BPM_VALUE, FSMState.DURATION_VALUE, FSMState.TIMESIG_VALUE]:
+ # Accumulate numeric value using token ID sequence
+ if generated_token_id == self.newline_token:
+ old_state = self.state
+ self._transition_to_next_state()
+ # IMPORTANT: After state transition, if new state is a fixed_strings state,
+ # we should NOT update position_in_state with the newline token length,
+ # because that token belongs to the old state, not the new state.
+ # Return early to avoid the fixed_strings update logic below.
+ if self.state in self.fixed_strings:
+ return
+ else:
+ # Add token ID to sequence (for prefix tree lookup)
+ self.accumulated_token_ids.append(generated_token_id)
+ # Also update legacy accumulated_value for compatibility
+ if token_str.strip().isdigit():
+ self.accumulated_value += token_str.strip()
+
+ elif self.state == FSMState.GENRES_VALUE:
+ if generated_token_id == self.newline_token:
+ # Newline ends the field
+ self._transition_to_next_state()
+ # IMPORTANT: After state transition, if new state is a fixed_strings state,
+ # we should NOT update position_in_state with the newline token length,
+ # because that token belongs to the old state, not the new state.
+ # Return early to avoid the fixed_strings update logic below.
+ if self.state in self.fixed_strings:
+ return
+ else:
+ # Genres still uses string-based trie, so keep accumulated_value
+ self.accumulated_value += token_str
+
+ elif self.state == FSMState.CAPTION_VALUE:
+ # Track token count for 512 limit
+ self.caption_token_count += 1
+
+ # Accumulate caption text
+ self.accumulated_value += token_str
+
+ # Track if this token contains a newline (for transition detection)
+ # Token may be '\n' alone or combined with other chars like '.\n'
+ if '\n' in token_str:
+ # Mark that we need to check next token for field transition
+ self.caption_after_newline = True
+ else:
+ # Not a newline - if we were after newline and this is not space,
+ # transition already happened in _process_single_sequence
+ self.caption_after_newline = False
+
+ # If caption is ending, accumulate field name and detect field completion
+ if self.caption_ending:
+ self.pending_field_name += token_str
+
+ # Check if we've completed a field name (detected colon)
+ if ':' in token_str or token_str.strip() == ':':
+ # Extract field name (before colon)
+ field_name_full = self.pending_field_name.strip()
+ # Remove trailing colon if present
+ field_name = field_name_full.rstrip(':').strip().lower()
+
+ if self.debug:
+ logger.debug(f"Detected field name after caption: {repr(field_name)}")
+
+ # Map field name to VALUE state
+ field_name_to_value_state = {
+ "duration": FSMState.DURATION_VALUE,
+ "genres": FSMState.GENRES_VALUE,
+ "keyscale": FSMState.KEYSCALE_VALUE,
+ "language": FSMState.LANGUAGE_VALUE,
+ "timesignature": FSMState.TIMESIG_VALUE,
+ }
+
+ if field_name in field_name_to_value_state:
+ # Transition directly to the field's VALUE state
+ old_state = self.state
+ self.state = field_name_to_value_state[field_name]
+ self.position_in_state = 0
+ self.accumulated_value = ""
+ self.accumulated_token_ids = []
+ self.caption_ending = False
+ self.pending_field_name = ""
+
+ if self.debug:
+ logger.debug(f"FSM transition (caption ending): {old_state.name} -> {self.state.name}")
+ else:
+ # Unknown field name, force transition to next field
+ if self.debug:
+ logger.warning(f"Unknown field name after caption: {repr(field_name)}, forcing transition")
+ self.caption_ending = False
+ self.pending_field_name = ""
+ self._transition_to_next_state()
+
+ elif self.state == FSMState.KEYSCALE_VALUE:
+ if generated_token_id == self.newline_token:
+ # Newline ends the field
+ self._transition_to_next_state()
+ # IMPORTANT: After state transition, if new state is a fixed_strings state,
+ # we should NOT update position_in_state with the newline token length,
+ # because that token belongs to the old state, not the new state.
+ # Return early to avoid the fixed_strings update logic below.
+ if self.state in self.fixed_strings:
+ return
+ else:
+ # Add token ID to sequence (for prefix tree lookup)
+ self.accumulated_token_ids.append(generated_token_id)
+ # Also update legacy accumulated_value for compatibility
+ self.accumulated_value += token_str
+
+ elif self.state == FSMState.LANGUAGE_VALUE:
+ if generated_token_id == self.newline_token:
+ # Newline ends the field
+ self._transition_to_next_state()
+ if self.state in self.fixed_strings:
+ return
+ else:
+ # Add token ID to sequence (for prefix tree lookup)
+ self.accumulated_token_ids.append(generated_token_id)
+ # Also update legacy accumulated_value for compatibility
+ self.accumulated_value += token_str
+