from enum import Enum, auto
from typing import Optional, Dict, Any, Tuple, List, Callable, Set
from loguru import logger
from transformers import AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor
import os
import torch
from acestep.constants import (
VALID_LANGUAGES,
KEYSCALE_NOTES,
KEYSCALE_ACCIDENTALS,
KEYSCALE_MODES,
VALID_KEYSCALES,
BPM_MIN,
BPM_MAX,
DURATION_MIN,
DURATION_MAX,
VALID_TIME_SIGNATURES,
)
# Maximum valid audio code value (codebook size = 64000)
MAX_AUDIO_CODE = 63999
# ==============================================================================
# FSM States for Constrained Decoding
# ==============================================================================
class FSMState(Enum):
    """Finite State Machine states for metadata generation.

    States mostly come in NAME/VALUE pairs per metadata field; NEWLINE_AFTER_*
    states mark transition points between fields.
    """
    # NOTE(review): several inline comments below showed empty quotes in the
    # previous revision — the "<think>"/"</think>" tag text was likely stripped
    # by an HTML-unaware text pipeline; restored here from the state names.
    THINK_TAG = auto()  # Generating "<think>"
    NEWLINE_AFTER_THINK = auto()  # Generating "\n" after "<think>"
    BPM_NAME = auto()  # Generating "bpm: "
    BPM_VALUE = auto()  # Generating numeric value 30-300
    NEWLINE_AFTER_BPM = auto()  # Generating "\n" after bpm value
    CAPTION_NAME = auto()  # Generating "caption: "
    CAPTION_VALUE = auto()  # Generating caption text (no code blocks/newlines)
    DURATION_NAME = auto()  # Generating "duration: "
    DURATION_VALUE = auto()  # Generating numeric value 10-600
    NEWLINE_AFTER_DURATION = auto()
    GENRES_NAME = auto()  # Generating "genres: "
    GENRES_VALUE = auto()  # Generating any non-empty string
    NEWLINE_AFTER_GENRES = auto()
    KEYSCALE_NAME = auto()  # Generating "keyscale: "
    KEYSCALE_VALUE = auto()  # Generating keyscale pattern
    NEWLINE_AFTER_KEYSCALE = auto()
    LANGUAGE_NAME = auto()  # Generating "language: "
    LANGUAGE_VALUE = auto()  # Generating language code (en, zh, ja, etc.)
    TIMESIG_NAME = auto()  # Generating "timesignature: "
    TIMESIG_VALUE = auto()  # Generating 2, 3, 4, or 6
    NEWLINE_AFTER_TIMESIG = auto()
    THINK_END_TAG = auto()  # Generating "</think>"
    CODES_GENERATION = auto()  # Generating audio codes (no constraints)
    COMPLETED = auto()  # Generation completed
class MetadataConstrainedLogitsProcessor(LogitsProcessor):
"""
FSM-driven LogitsProcessor that constrains generation to produce valid metadata.
This processor enforces the following format:
bpm: [30-300]
caption: [text without code blocks, ends with period + newline]
duration: [10-600]
keyscale: [A-G][#/♭]? [major/minor]
language: [en/zh/ja/ko/es/fr/de/uk/ru/...]
timesignature: [2/3/4/6]
It uses token masking (setting invalid token logits to -inf) to enforce constraints.
For numeric fields, it uses early-blocking to prevent out-of-range values.
For field transitions (e.g., end of numeric value), it compares P(newline) vs P(digit).
For caption field, it blocks code blocks and newlines, and only transitions when
the previous token was a period and newline has the highest probability.
"""
def __init__(
    self,
    tokenizer: AutoTokenizer,
    enabled: bool = True,
    debug: bool = False,
    genres_vocab_path: Optional[str] = None,
    skip_genres: bool = True,
):
    """
    Initialize the constrained logits processor.

    This processor should be initialized once when loading the LLM and reused
    for all generations.

    Args:
        tokenizer: The tokenizer to use for encoding/decoding
        enabled: Whether to enable constrained decoding
        debug: Whether to print debug information
        genres_vocab_path: Optional path to the genres vocabulary file;
            defaults to "genres_vocab.txt" next to this module
        skip_genres: Whether the genres field is skipped by default
    """
    self.tokenizer = tokenizer
    self.enabled = enabled
    self.debug = debug
    self.skip_genres = skip_genres
    self.skip_caption = False  # Set to True to skip caption field generation
    self.skip_language = False  # Set to True to skip language field generation
    self.caption: Optional[str] = None  # Set via update_caption() before each generation
    # User-provided metadata fields (optional)
    # If provided, these fields will be used directly instead of generating
    # Format: {"bpm": "120", "caption": "...", "duration": "234", "keyscale": "G major", "language": "en", "timesignature": "4"}
    self.user_provided_metadata: Dict[str, Optional[str]] = {
        "bpm": None,
        "caption": None,
        "duration": None,
        "keyscale": None,
        "language": None,
        "timesignature": None,
        "genres": None,
    }
    # Temperature settings for different generation phases (set per-generation)
    # If set, the processor will apply temperature scaling (divide logits by temperature)
    # Note: Set base sampler temperature to 1.0 when using processor-based temperature
    self.metadata_temperature: Optional[float] = None
    self.codes_temperature: Optional[float] = None
    # Duration constraint for codes generation
    # 5 codes = 1 second, so target_codes = target_duration * 5
    self.target_duration: Optional[float] = None  # User-specified duration in seconds
    self.target_codes: Optional[int] = None  # Computed target codes count
    self.codes_count: int = 0  # Counter for generated codes
    # Stop at reasoning flag - if True, stop generation after the </think> tag
    self.stop_at_reasoning: bool = False
    # Generation phase - "cot" or "codes"
    # Used to determine FSM behavior when prompt already contains CoT
    self.generation_phase: str = "cot"
    # Current state
    self.state = FSMState.THINK_TAG
    self.position_in_state = 0  # Position within current state's fixed string
    self.accumulated_value = ""  # For numeric/text value accumulation (legacy, for compatibility)
    self.accumulated_token_ids: List[int] = []  # Token ID sequence for keyscale (and other fields)
    # Caption generation state tracking
    self.caption_after_newline = False  # Track if we're right after a newline in caption
    self.caption_token_count = 0  # Track token count for caption (max 512)
    self.caption_ending = False  # Track if caption is ending (after detecting non-indented line)
    self.pending_field_name = ""  # Accumulate field name tokens when caption is ending
    # Token queue for user-provided fields (injected directly without generation)
    self.user_field_token_queue: List[int] = []
    self.current_user_field: Optional[str] = None  # Current field being injected
    # Pre-compute token IDs for efficiency
    self._precompute_tokens()
    # Genres vocabulary for constrained decoding
    self.genres_vocab_path = genres_vocab_path or os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "genres_vocab.txt"
    )
    self.genres_vocab: List[str] = []  # Full vocab
    self.genres_vocab_mtime: float = 0.0
    self.genres_trie: Dict = {}  # Trie for full vocab (fallback)
    self.caption_genres_trie: Dict = {}  # Trie for caption-matched genres (priority)
    self.caption_matched_genres: List[str] = []  # Genres matched from caption
    self._char_to_tokens: Dict[str, set] = {}  # Precomputed char -> token IDs mapping
    # Precompute token mappings once (O(vocab_size), runs once at init)
    self._precompute_char_token_mapping()
    # Field definitions (needed before building prefix trees)
    self.field_specs = {
        "bpm": {"min": BPM_MIN, "max": BPM_MAX},
        "duration": {"min": DURATION_MIN, "max": DURATION_MAX},
        "timesignature": {"valid_values": VALID_TIME_SIGNATURES},
    }
    # Build valid numeric values for BPM, Duration, Timesignature
    # These will be used to build prefix trees based on actual tokenization
    self.valid_bpm_values = [str(v) for v in range(self.field_specs["bpm"]["min"], self.field_specs["bpm"]["max"] + 1)]
    self.valid_duration_values = [str(v) for v in range(self.field_specs["duration"]["min"], self.field_specs["duration"]["max"] + 1)]
    self.valid_timesig_values = [str(v) for v in self.field_specs["timesignature"]["valid_values"]]
    # Build keyscale prefix tree (requires _char_to_tokens to be initialized)
    self.keyscale_prefix_tree = self._build_keyscale_prefix_tree()
    # Build numeric prefix trees (BPM, Duration, Timesignature) with context
    # IMPORTANT: State machine generates "bpm:" (no space), but tokenizer sees "bpm: " (with space)
    # Use same logic as keyscale: context_prefix_for_matching (no space) and context_prefix_for_tokenization (with space)
    self.bpm_prefix_tree = self._build_numeric_prefix_tree(
        self.valid_bpm_values,
        context_prefix_for_matching="bpm:",
        context_prefix_for_tokenization="bpm: "
    )
    self.duration_prefix_tree = self._build_numeric_prefix_tree(
        self.valid_duration_values,
        context_prefix_for_matching="duration:",
        context_prefix_for_tokenization="duration: "
    )
    self.timesig_prefix_tree = self._build_numeric_prefix_tree(
        self.valid_timesig_values,
        context_prefix_for_matching="timesignature:",
        context_prefix_for_tokenization="timesignature: "
    )
    # Build language prefix tree (similar to keyscale but for language codes)
    self.language_prefix_tree = self._build_language_prefix_tree()
    self._load_genres_vocab()
    # Fixed strings for each state
    # IMPORTANT: Do NOT include trailing space after colon - tokenizer will handle spacing
    # All matching should be done at token level, not string level
    # NOTE: NEWLINE_AFTER_* states are removed - field values generate newline directly and transition to next field
    # BUG FIX (review): the THINK_TAG/THINK_END_TAG literals were empty strings
    # in the previous revision, which contradicts the state names and the
    # surrounding comments — almost certainly "<think>"/"</think>" stripped by
    # an HTML-unaware pipeline. Restored; confirm against the chat template.
    self.fixed_strings = {
        FSMState.THINK_TAG: "<think>",
        FSMState.NEWLINE_AFTER_THINK: "\n",
        FSMState.BPM_NAME: "bpm:",
        FSMState.CAPTION_NAME: "caption:",
        FSMState.DURATION_NAME: "duration:",
        FSMState.GENRES_NAME: "genres:",
        FSMState.KEYSCALE_NAME: "keyscale:",
        FSMState.LANGUAGE_NAME: "language:",
        FSMState.TIMESIG_NAME: "timesignature:",
        FSMState.THINK_END_TAG: "</think>",
    }
    # State transitions
    self._build_state_transitions()
def _get_next_field_state(self, current_field: str) -> Optional[FSMState]:
    """
    Return the NAME state of the field that follows ``current_field``.

    The field NAME is generated even for user-provided fields; only the
    skip flags (caption/genres/language) remove a field from the sequence.

    Args:
        current_field: One of "bpm", "caption", "duration", "genres",
            "keyscale", "language", "timesignature".

    Returns:
        The next field's NAME state, or THINK_END_TAG when no fields remain
        (also returned for an unknown ``current_field``).
    """
    # Field order: bpm -> caption -> duration -> genres -> keyscale -> language -> timesignature
    ordered_fields = ["bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature"]
    name_states = {
        "bpm": FSMState.BPM_NAME,
        "caption": FSMState.CAPTION_NAME,
        "duration": FSMState.DURATION_NAME,
        "genres": FSMState.GENRES_NAME,
        "keyscale": FSMState.KEYSCALE_NAME,
        "language": FSMState.LANGUAGE_NAME,
        "timesignature": FSMState.TIMESIG_NAME,
    }
    skip_flags = {
        "genres": self.skip_genres,
        "caption": self.skip_caption,
        "language": self.skip_language,
    }
    if current_field not in ordered_fields:
        return FSMState.THINK_END_TAG
    start = ordered_fields.index(current_field) + 1
    for candidate in ordered_fields[start:]:
        if skip_flags.get(candidate, False):
            continue
        # Even user-provided fields get a NAME state: the field name is generated
        return name_states[candidate]
    # Ran off the end of the field list
    return FSMState.THINK_END_TAG
def _build_state_transitions(self):
    """
    Rebuild ``self.next_state``, the FSM transition table.

    Each active field contributes NAME -> VALUE, and VALUE -> the following
    field's NAME state (per the skip flags). User-provided fields still get
    entries because their field names are generated regardless.
    """
    transitions = {
        FSMState.THINK_TAG: FSMState.NEWLINE_AFTER_THINK,
        FSMState.NEWLINE_AFTER_THINK: FSMState.BPM_NAME,  # Metadata always opens with bpm
        FSMState.THINK_END_TAG: FSMState.CODES_GENERATION,
        FSMState.CODES_GENERATION: FSMState.COMPLETED,
    }
    # Field chain: bpm -> caption -> duration -> genres -> keyscale -> language -> timesignature
    # (field, NAME state, VALUE state, skipped?)
    field_chain = [
        ("bpm", FSMState.BPM_NAME, FSMState.BPM_VALUE, False),
        ("caption", FSMState.CAPTION_NAME, FSMState.CAPTION_VALUE, self.skip_caption),
        ("duration", FSMState.DURATION_NAME, FSMState.DURATION_VALUE, False),
        ("genres", FSMState.GENRES_NAME, FSMState.GENRES_VALUE, self.skip_genres),
        ("keyscale", FSMState.KEYSCALE_NAME, FSMState.KEYSCALE_VALUE, False),
        ("language", FSMState.LANGUAGE_NAME, FSMState.LANGUAGE_VALUE, self.skip_language),
        ("timesignature", FSMState.TIMESIG_NAME, FSMState.TIMESIG_VALUE, False),
    ]
    for field, name_state, value_state, skipped in field_chain:
        if skipped:
            continue
        transitions[name_state] = value_state
        # _get_next_field_state honours the same skip flags, and returns
        # THINK_END_TAG after the last field (timesignature)
        transitions[value_state] = self._get_next_field_state(field)
    self.next_state = transitions
def set_skip_genres(self, skip: bool):
    """Enable or disable the genres field and refresh the transition table.

    Args:
        skip: True to drop "genres:" from the generated metadata sequence.
    """
    self.skip_genres = skip
    self._build_state_transitions()
def set_skip_caption(self, skip: bool):
    """Enable or disable the caption field and refresh the transition table.

    Args:
        skip: True to drop "caption:" from the generated metadata sequence.
    """
    self.skip_caption = skip
    self._build_state_transitions()
def set_skip_language(self, skip: bool):
    """Enable or disable the language field and refresh the transition table.

    Args:
        skip: True to drop "language:" from the generated metadata sequence.
    """
    self.skip_language = skip
    self._build_state_transitions()
@staticmethod
def postprocess_caption(caption: str) -> str:
"""
Post-process caption to remove YAML multi-line formatting.
Converts YAML-style multi-line text (with newlines and leading spaces)
to a single-line string.
Example:
Input: "An emotional ballad.\\n The track opens with piano.\\n More text."
Output: "An emotional ballad. The track opens with piano. More text."
Args:
caption: Raw caption text with possible YAML formatting
Returns:
Clean single-line caption
"""
if not caption:
return caption
# Split by newlines
lines = caption.split('\n')
# Process each line: strip leading/trailing whitespace
cleaned_lines = []
for line in lines:
stripped = line.strip()
if stripped:
cleaned_lines.append(stripped)
# Join with single space
return ' '.join(cleaned_lines)
def set_stop_at_reasoning(self, stop: bool):
"""
Set whether to stop generation after tag.
Args:
stop: If True, generation will stop immediately after tag is generated.
If False, generation continues to codes generation phase.
"""
self.stop_at_reasoning = stop
def set_generation_phase(self, phase: str):
"""
Set the generation phase.
Args:
phase: "cot" for CoT metadata generation, "codes" for audio codes generation,
or "understand" for audio understanding (codes → metadata + lyrics).
When phase is "codes" and the input prompt already contains ,
the FSM will skip metadata generation and go directly to codes generation.
When phase is "understand", generate CoT metadata then free-form lyrics.
"""
if phase not in ("cot", "codes", "understand"):
raise ValueError(f"Invalid generation phase: {phase!r}. Must be 'cot', 'codes', or 'understand'")
self.generation_phase = phase
def set_user_metadata(self, metadata: Optional[Dict[str, Optional[str]]] = None):
    """
    Register user-provided metadata fields.

    Provided fields are injected verbatim instead of being generated; fields
    missing from ``metadata`` (or set to None) are generated by the model.
    Passing None clears every previously registered field.

    Args:
        metadata: Optional mapping from field name — "bpm", "caption",
            "duration", "keyscale", "language", "timesignature", "genres" —
            to its string value, e.g. {"bpm": "120", "keyscale": "G major"}.
    """
    supplied = metadata or {}
    # dict.get yields None for absent fields, clearing stale values
    for field in ["bpm", "caption", "duration", "keyscale", "language", "timesignature", "genres"]:
        self.user_provided_metadata[field] = supplied.get(field)
    # Transition table depends on which fields are active
    self._build_state_transitions()
    if self.debug:
        provided_fields = [k for k, v in self.user_provided_metadata.items() if v is not None]
        if provided_fields:
            logger.debug(f"User provided metadata fields: {provided_fields}")
        else:
            logger.debug("No user-provided metadata, all fields will be generated")
def _precompute_tokens(self):
    """Pre-compute commonly used token IDs for efficiency.

    Runs once at init. Each lookup encodes a short string and keeps the LAST
    token id, since some tokenizers emit a prefix token before the character.
    """
    # Digit tokens (0-9)
    self.digit_tokens = {}
    for d in range(10):
        tokens = self.tokenizer.encode(str(d), add_special_tokens=False)
        if tokens:
            self.digit_tokens[d] = tokens[-1]  # Take last token (in case of prefix)
    # Newline token
    newline_tokens = self.tokenizer.encode("\n", add_special_tokens=False)
    self.newline_token = newline_tokens[-1] if newline_tokens else None
    # Note tokens for keyscale (A-G)
    self.note_tokens = {}
    for note in KEYSCALE_NOTES:
        tokens = self.tokenizer.encode(note, add_special_tokens=False)
        if tokens:
            self.note_tokens[note] = tokens[-1]
    # Sharp/flat tokens (ASCII and Unicode accidentals)
    self.sharp_tokens = []
    for s in ["#", "♯"]:
        tokens = self.tokenizer.encode(s, add_special_tokens=False)
        if tokens:
            self.sharp_tokens.append(tokens[-1])
    self.flat_tokens = []
    for f in ["b", "♭"]:
        tokens = self.tokenizer.encode(f, add_special_tokens=False)
        if tokens:
            self.flat_tokens.append(tokens[-1])
    # Space token
    space_tokens = self.tokenizer.encode(" ", add_special_tokens=False)
    self.space_token = space_tokens[-1] if space_tokens else None
    # Major/minor tokens (we'll encode the full words)
    # NOTE(review): prefix.lower() == "m" is true for BOTH "m" and "M", so each
    # iteration appends to both lists — confirm whether "M" was meant to feed
    # only major_start_tokens.
    self.major_start_tokens = []
    self.minor_start_tokens = []
    for prefix in ["m", "M"]:
        tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
        if tokens:
            if prefix.lower() == "m":
                self.minor_start_tokens.append(tokens[-1])
                self.major_start_tokens.append(tokens[-1])  # "major" also starts with m
    # Vocab size
    self.vocab_size = len(self.tokenizer)
    # Comma token for multi-genre support
    comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
    self.comma_token = comma_tokens[-1] if comma_tokens else None
    # EOS token for duration-constrained codes generation
    self.eos_token_id = self.tokenizer.eos_token_id
    # Period token for caption field transition logic
    period_tokens = self.tokenizer.encode(".", add_special_tokens=False)
    self.period_token = period_tokens[-1] if period_tokens else None
    # Backtick tokens for blocking code blocks in caption
    backtick_tokens = self.tokenizer.encode("`", add_special_tokens=False)
    self.backtick_token = backtick_tokens[-1] if backtick_tokens else None
    # Valid language codes (ISO 639-1 and common variants)
    self.valid_languages = VALID_LANGUAGES
    # Precompute audio code token IDs (tokens matching <|audio_code_\d+|>)
    # These should be blocked during caption generation
    self.audio_code_token_ids: Set[int] = set()
    self._precompute_audio_code_tokens()
    # Precompute audio code mask for efficient blocking (O(1) instead of O(n))
    # This mask will be added to scores during caption generation
    self.audio_code_mask: Optional[torch.Tensor] = None
    # Inverse mask: block all non-audio-code tokens (for CODES_GENERATION state)
    self.non_audio_code_mask: Optional[torch.Tensor] = None
    self._build_audio_code_mask()
    # Build valid keyscales set (prefix tree will be built after _char_to_tokens is initialized)
    # 7 notes × 5 accidentals (none, #, b, ♯, ♭) × 2 modes = 70 valid combinations
    self.valid_keyscales = VALID_KEYSCALES.copy()
    # keyscale_prefix_tree will be built in _precompute_char_token_mapping() after _char_to_tokens is ready
    # Numeric prefix trees will be built after field_specs is defined
def _precompute_audio_code_tokens(self):
    """
    Scan the vocabulary for audio code tokens (<|audio_code_N|>).

    Token IDs whose decoded text matches the pattern and whose code value N
    lies in [0, MAX_AUDIO_CODE] are collected into self.audio_code_token_ids;
    matching tokens with out-of-range values are counted and reported.
    """
    import re
    code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
    invalid_tokens_count = 0
    decode = self.tokenizer.decode  # hoisted for the vocab-wide scan
    for token_id in range(self.vocab_size):
        try:
            match = code_pattern.match(decode([token_id]))
        except Exception:
            # Ids that fail to decode cannot be audio code tokens
            continue
        if not match:
            continue
        code_value = int(match.group(1))
        # Only keep tokens with valid code values (0-63999)
        if 0 <= code_value <= MAX_AUDIO_CODE:
            self.audio_code_token_ids.add(token_id)
        else:
            invalid_tokens_count += 1
            if self.debug:
                logger.debug(f"Skipping audio code token {token_id} with invalid code value {code_value} (max: {MAX_AUDIO_CODE})")
    if invalid_tokens_count > 0:
        logger.warning(f"Found {invalid_tokens_count} audio code tokens with values outside valid range [0, {MAX_AUDIO_CODE}]")
    # An empty set would make codes generation impossible downstream
    if len(self.audio_code_token_ids) == 0:
        logger.warning(f"No valid audio code tokens found in vocabulary (range [0, {MAX_AUDIO_CODE}]). Code generation may fail.")
    elif self.debug:
        logger.debug(f"Found {len(self.audio_code_token_ids)} valid audio code tokens (range [0, {MAX_AUDIO_CODE}])")
def _extract_code_from_token(self, token_id: int) -> Optional[int]:
"""
Extract audio code value from a token ID.
Args:
token_id: Token ID to extract code value from
Returns:
Code value if token is a valid audio code token, None otherwise
"""
import re
audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
try:
token_text = self.tokenizer.decode([token_id])
match = audio_code_pattern.match(token_text)
if match:
return int(match.group(1))
except Exception:
pass
return None
def _build_audio_code_mask(self):
"""
Build a precomputed mask tensor for blocking audio code tokens.
This mask can be added to scores in O(1) time instead of O(n) loop.
The mask is [1, vocab_size] tensor with -inf at audio code token positions.
Also builds the inverse mask (non_audio_code_mask) for CODES_GENERATION state,
which blocks all non-audio-code tokens.
"""
if not self.audio_code_token_ids:
self.audio_code_mask = None
self.non_audio_code_mask = None
return
# Create mask tensor: 0 everywhere, -inf at audio code positions
# Use float32 for compatibility with most model dtypes
mask = torch.zeros(1, self.vocab_size, dtype=torch.float32)
# Convert set to list for indexing
audio_code_indices = list(self.audio_code_token_ids)
# Set -inf at audio code token positions
mask[0, audio_code_indices] = float('-inf')
self.audio_code_mask = mask
# Build inverse mask: -inf everywhere EXCEPT at audio code positions
# This is used in CODES_GENERATION state to only allow audio codes
inverse_mask = torch.full((1, self.vocab_size), float('-inf'), dtype=torch.float32)
inverse_mask[0, audio_code_indices] = 0
# Also allow EOS token in codes generation (will be controlled by duration constraint)
if self.eos_token_id is not None:
inverse_mask[0, self.eos_token_id] = 0
self.non_audio_code_mask = inverse_mask
if self.debug:
logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens")
def _apply_whitelist_inplace(self, scores: torch.Tensor, allowed_tokens: List[int]) -> None:
"""
Apply whitelist constraint inplace: only allow specified tokens, block all others.
This is more efficient than creating a mask tensor because:
1. No memory allocation for mask
2. No tensor addition operation
Args:
scores: [1, vocab_size] scores tensor to modify inplace
allowed_tokens: List of token IDs to allow (all others will be set to -inf)
"""
if not allowed_tokens:
# No tokens allowed, set all to -inf
scores.fill_(float('-inf'))
return
# Save the original values of allowed tokens
allowed_indices = torch.tensor(allowed_tokens, device=scores.device, dtype=torch.long)
saved_values = scores[0, allowed_indices].clone()
# Set all scores to -inf
scores.fill_(float('-inf'))
# Restore allowed token values
scores[0, allowed_indices] = saved_values
def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
    """
    Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.

    IMPORTANT: Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches.

    CRITICAL FIX: The tokenizer may merge the context's trailing space into the next token.
    For example:
    - "keyscale: " tokenizes to [10563, 2246, 25, 220] -> ['keys', 'cale', ':', ' ']
    - "keyscale: G major" tokenizes to [10563, 2246, 25, 479, 3598] -> ['keys', 'cale', ':', ' G', ' major']
    The space ' ' (220) is merged into ' G' (479), so we can't use simple slicing.

    Strategy:
    1. For each keyscale (e.g., "G major"), encode the FULL string "keyscale: G major"
    2. Tokenize to get: [10563, 2246, 25, 479, 3598] -> ['keys', 'cale', ':', ' G', ' major']
    3. Find where context prefix ends by matching token sequences (handling space merging)
    4. Extract keyscale value tokens: [479, 3598] (for "G major")
    5. Build prefix tree using token ID sequences as keys

    This ensures we get the exact tokenization that occurs during generation.

    Returns:
        Mapping from token-ID-sequence prefix -> set of token IDs allowed
        next; a complete keyscale additionally allows the newline token.
    """
    prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}
    # Context prefix that appears before keyscale value
    # IMPORTANT: The state machine generates "keyscale:" (no space), but when tokenizing
    # the full string "keyscale: G major", the tokenizer includes space, so we need to
    # match the actual tokenization behavior.
    #
    # Strategy:
    # 1. Use "keyscale:" (no space) to match the state machine's output
    # 2. But when building prefix tree, use "keyscale: " (with space) + keyscale to match actual tokenization
    context_prefix_for_matching = "keyscale:"  # What state machine generates
    context_prefix_for_tokenization = "keyscale: "  # What tokenizer sees in full string
    # First, tokenize the context (without space) to know its token sequence for matching
    context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False)
    if self.debug:
        context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids]
        logger.debug(f"Context for matching 'keyscale:' tokenizes to {context_token_ids} -> {context_tokens_str}")
    # For each valid keyscale, encode full string and extract value tokens
    for keyscale in self.valid_keyscales:
        # Step 1: Encode full string "keyscale: {keyscale}" (with space, as tokenizer sees it)
        full_text = context_prefix_for_tokenization + keyscale
        full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)
        # Step 2: Find where context ends in full_token_ids
        # We match using context_prefix_for_matching ("keyscale:") token sequence
        # because that's what the state machine actually generates
        context_end_idx = None
        # Try exact prefix match using context_prefix_for_matching token sequence
        if len(full_token_ids) >= len(context_token_ids):
            if full_token_ids[:len(context_token_ids)] == context_token_ids:
                context_end_idx = len(context_token_ids)
        if context_end_idx is None:
            if self.debug:
                logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
            continue
        # Step 3: Extract keyscale value tokens (everything after context)
        keyscale_token_ids = full_token_ids[context_end_idx:]
        # Step 4: Verify we extracted some tokens (sanity check)
        if not keyscale_token_ids:
            if self.debug:
                logger.warning(f"No tokens extracted for keyscale '{keyscale}', skipping")
            continue
        # Step 5: Verify first token is a note (A-G)
        # This is critical: the first token of keyscale value must be a note
        first_token_id = keyscale_token_ids[0]
        first_token_str = self.tokenizer.decode([first_token_id])
        # Check if first token starts with a note (A-G, case insensitive, with optional leading space)
        first_char = first_token_str.lstrip()[0].upper() if first_token_str.lstrip() else ""
        if first_char not in "ABCDEFG":
            # This keyscale's first token is not a note - skip it
            if self.debug:
                logger.debug(f"Skipping keyscale '{keyscale}': first token is '{first_token_str}' (id={first_token_id}), not a note")
            continue
        # Step 6: Build prefix mappings from keyscale value tokens
        # Use token ID sequences as keys (not strings) to avoid tokenization mismatches
        for i in range(len(keyscale_token_ids) + 1):
            # Current token sequence prefix (empty tuple for start)
            token_prefix = tuple(keyscale_token_ids[:i])
            if token_prefix not in prefix_to_tokens:
                prefix_to_tokens[token_prefix] = set()
            if i < len(keyscale_token_ids):
                # Add next token as allowed for current prefix
                next_token_id = keyscale_token_ids[i]
                prefix_to_tokens[token_prefix].add(next_token_id)
            else:
                # Complete keyscale should allow newline
                if self.newline_token:
                    prefix_to_tokens[token_prefix].add(self.newline_token)
    if self.debug:
        logger.debug(f"Built keyscale prefix tree with {len(prefix_to_tokens)} token sequence prefixes")
        # Check empty prefix (start of keyscale value)
        empty_prefix = tuple()
        if empty_prefix in prefix_to_tokens:
            first_tokens = prefix_to_tokens[empty_prefix]
            decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)]
            logger.debug(f"First tokens allowed (empty prefix): {decoded_first}")
    return prefix_to_tokens
def _build_numeric_prefix_tree(
self,
valid_values: List[str],
context_prefix_for_matching: str = "",
context_prefix_for_tokenization: str = ""
) -> Dict[Tuple[int, ...], Set[int]]:
"""
Build prefix tree for numeric field based on actual tokenization with context.
IMPORTANT: Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches.
Args:
valid_values: List of valid numeric strings (e.g., ["30", "31", ..., "300"])
context_prefix_for_matching: Context string that state machine generates (e.g., "bpm:") - no space
context_prefix_for_tokenization: Context string for tokenization (e.g., "bpm: ") - with space
Returns:
Dict mapping token ID sequence prefix -> set of allowed token IDs
"""
prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}
# Encode context for matching (what state machine generates, no space)
context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False) if context_prefix_for_matching else []
# For each valid value, encode it with context and build prefix mappings
for value_str in valid_values:
# Encode value WITH context (with space) to match actual tokenization
full_text = context_prefix_for_tokenization + value_str
token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)
# Find where context ends in full_token_ids using context_prefix_for_matching token sequence
context_end_idx = None
if len(token_ids) >= len(context_token_ids):
if token_ids[:len(context_token_ids)] == context_token_ids:
context_end_idx = len(context_token_ids)
if context_end_idx is None:
if self.debug:
logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
continue
# Extract only tokens that belong to the value itself (skip context tokens)
value_token_ids = token_ids[context_end_idx:]
# Build prefix mappings using token ID sequences as keys
for i in range(len(value_token_ids) + 1):
# Current token sequence prefix (empty tuple for start)
token_prefix = tuple(value_token_ids[:i])
if token_prefix not in prefix_to_tokens:
prefix_to_tokens[token_prefix] = set()
if i < len(value_token_ids):
# Add next token as allowed for current prefix
next_token_id = value_token_ids[i]
prefix_to_tokens[token_prefix].add(next_token_id)
else:
# Complete value should allow newline
if self.newline_token:
prefix_to_tokens[token_prefix].add(self.newline_token)
return prefix_to_tokens
def _build_language_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
"""
Build language prefix to allowed tokens mapping based on ACTUAL tokenization.
Similar to keyscale prefix tree but for language codes.
Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches.
"""
prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}
context_prefix_for_matching = "language:"
context_prefix_for_tokenization = "language: "
context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False)
if self.debug:
context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids]
logger.debug(f"Context for matching 'language:' tokenizes to {context_token_ids} -> {context_tokens_str}")
for lang in self.valid_languages:
full_text = context_prefix_for_tokenization + lang
full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)
context_end_idx = None
if len(full_token_ids) >= len(context_token_ids):
if full_token_ids[:len(context_token_ids)] == context_token_ids:
context_end_idx = len(context_token_ids)
if context_end_idx is None:
if self.debug:
logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
continue
lang_token_ids = full_token_ids[context_end_idx:]
if not lang_token_ids:
if self.debug:
logger.warning(f"No tokens extracted for language '{lang}', skipping")
continue
for i in range(len(lang_token_ids) + 1):
token_prefix = tuple(lang_token_ids[:i])
if token_prefix not in prefix_to_tokens:
prefix_to_tokens[token_prefix] = set()
if i < len(lang_token_ids):
next_token_id = lang_token_ids[i]
prefix_to_tokens[token_prefix].add(next_token_id)
else:
if self.newline_token:
prefix_to_tokens[token_prefix].add(self.newline_token)
if self.debug:
logger.debug(f"Built language prefix tree with {len(prefix_to_tokens)} token sequence prefixes")
empty_prefix = tuple()
if empty_prefix in prefix_to_tokens:
first_tokens = prefix_to_tokens[empty_prefix]
decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)]
logger.debug(f"First tokens allowed for language (empty prefix): {decoded_first}")
return prefix_to_tokens
def diagnose_keyscale_prefix_tree(self):
"""
Diagnose the keyscale prefix tree to help debug generation bias.
Call this method to print detailed information about allowed tokens at each prefix.
"""
print("=" * 60)
print("KEYSCALE PREFIX TREE DIAGNOSIS")
print("=" * 60)
# Check empty prefix (first token)
if "" in self.keyscale_prefix_tree:
first_tokens = self.keyscale_prefix_tree[""]
print(f"\n[Empty prefix] Allowed first tokens ({len(first_tokens)} total):")
for t in sorted(first_tokens):
decoded = self.tokenizer.decode([t])
print(f" Token {t}: {repr(decoded)}")
else:
print("\nWARNING: Empty prefix not in tree!")
# Check some common prefixes
test_prefixes = ["A", "B", "C", "D", "E", "F", "G"]
for prefix in test_prefixes:
# Try both with and without potential tokenizer artifacts
for test_key in [prefix, prefix + " "]:
if test_key in self.keyscale_prefix_tree:
tokens = self.keyscale_prefix_tree[test_key]
print(f"\n[Prefix {repr(test_key)}] Allowed tokens ({len(tokens)}):")
for t in sorted(tokens):
decoded = self.tokenizer.decode([t])
print(f" Token {t}: {repr(decoded)}")
# Show some complete keyscales that should be valid
print(f"\n[Valid keyscales] Total: {len(self.valid_keyscales)}")
sample = sorted(list(self.valid_keyscales))[:10]
for ks in sample:
print(f" {repr(ks)}")
print("=" * 60)
def _load_genres_vocab(self):
"""
Load genres vocabulary from file. Supports hot reload by checking file mtime.
File format: one genre per line, lines starting with # are comments.
"""
if not os.path.exists(self.genres_vocab_path):
if self.debug:
logger.debug(f"Genres vocab file not found: {self.genres_vocab_path}")
return
try:
mtime = os.path.getmtime(self.genres_vocab_path)
if mtime <= self.genres_vocab_mtime:
return # File hasn't changed
with open(self.genres_vocab_path, 'r', encoding='utf-8') as f:
genres = []
for line in f:
line = line.strip()
if line and not line.startswith('#'):
genres.append(line.lower())
self.genres_vocab = genres
self.genres_vocab_mtime = mtime
self._build_genres_trie()
if self.debug:
logger.debug(f"Loaded {len(self.genres_vocab)} genres from {self.genres_vocab_path}")
except Exception as e:
logger.warning(f"Failed to load genres vocab: {e}")
def _build_genres_trie(self):
"""
Build a trie (prefix tree) from genres vocabulary for efficient prefix matching.
Each node is a dict with:
- '_end': True if this node represents a complete genre
- other keys: next characters in the trie
"""
self.genres_trie = {}
for genre in self.genres_vocab:
node = self.genres_trie
for char in genre:
if char not in node:
node[char] = {}
node = node[char]
node['_end'] = True # Mark end of a complete genre
if self.debug:
logger.debug(f"Built genres trie with {len(self.genres_vocab)} entries")
def _extract_caption_genres(self, caption: str):
"""
Extract genres from the user's caption that match entries in the vocabulary.
This creates a smaller trie for faster and more relevant genre generation.
Strategy (optimized - O(words * max_genre_len) instead of O(vocab_size)):
1. Extract words/phrases from caption
2. For each word, use trie to find all vocab entries that START with this word
3. Build a separate trie from matched genres
"""
if not caption or not self.genres_vocab:
return
caption_lower = caption.lower()
matched_genres = set()
# Extract words from caption (split by common delimiters)
import re
words = re.split(r'[,\s\-_/\\|]+', caption_lower)
words = [w.strip() for w in words if w.strip() and len(w.strip()) >= 2]
# For each word, find genres in trie that start with this word
for word in words:
# Find all genres starting with this word using trie traversal
node = self._get_genres_trie_node(word)
if node is not None:
# Collect all complete genres under this node
self._collect_complete_genres(node, word, matched_genres)
# Also check if any word appears as a substring in short genres (< 20 chars)
# This is a quick check for common single-word genres
genres_set = set(self.genres_vocab)
for word in words:
if word in genres_set:
matched_genres.add(word)
if not matched_genres:
if self.debug:
logger.debug(f"No genres matched in caption, using full vocab")
return
# Build a trie from matched genres
self.caption_matched_genres = list(matched_genres)
self.caption_genres_trie = {}
for genre in matched_genres:
node = self.caption_genres_trie
for char in genre:
if char not in node:
node[char] = {}
node = node[char]
node['_end'] = True
if self.debug:
logger.debug(f"Matched {len(matched_genres)} genres from caption: {list(matched_genres)[:5]}...")
def _collect_complete_genres(self, node: Dict, prefix: str, result: set, max_depth: int = 50):
"""
Recursively collect all complete genres under a trie node.
Limited depth to avoid too many matches.
"""
if max_depth <= 0:
return
if node.get('_end', False):
result.add(prefix)
# Limit total collected genres to avoid slowdown
if len(result) >= 100:
return
for char, child_node in node.items():
if char not in ('_end', '_tokens'):
self._collect_complete_genres(child_node, prefix + char, result, max_depth - 1)
def _precompute_char_token_mapping(self):
"""
Precompute mapping from characters to token IDs and token decoded texts.
This allows O(1) lookup instead of calling tokenizer.encode()/decode() at runtime.
Time complexity: O(vocab_size) - runs once during initialization
Note: Many subword tokenizers (like Qwen) add space prefixes to tokens.
We need to handle both the raw first char and the first non-space char.
"""
self._char_to_tokens: Dict[str, set] = {}
self._token_to_text: Dict[int, str] = {} # Precomputed decoded text for each token
# For each token in vocabulary, get its decoded text
for token_id in range(self.vocab_size):
try:
text = self.tokenizer.decode([token_id])
if not text:
continue
# Store the decoded text (normalized to lowercase)
# Keep leading spaces for proper concatenation (e.g., " rock" in "pop rock")
# Only rstrip trailing whitespace, unless it's a pure whitespace token
text_lower = text.lower()
if text_lower.strip(): # Has non-whitespace content
normalized_text = text_lower.rstrip()
else: # Pure whitespace token
normalized_text = " " # Normalize to single space
self._token_to_text[token_id] = normalized_text
# Map first character (including space) to this token
first_char = text[0].lower()
if first_char not in self._char_to_tokens:
self._char_to_tokens[first_char] = set()
self._char_to_tokens[first_char].add(token_id)
# Also map first non-space character to this token
# This handles tokenizers that add space prefixes (e.g., " pop" -> maps to 'p')
stripped_text = text.lstrip()
if stripped_text and stripped_text != text:
first_nonspace_char = stripped_text[0].lower()
if first_nonspace_char not in self._char_to_tokens:
self._char_to_tokens[first_nonspace_char] = set()
self._char_to_tokens[first_nonspace_char].add(token_id)
except Exception:
continue
if self.debug:
logger.debug(f"Precomputed char->token mapping for {len(self._char_to_tokens)} unique characters")
def _try_reload_genres_vocab(self):
"""Check if genres vocab file has been updated and reload if necessary."""
if not os.path.exists(self.genres_vocab_path):
return
try:
mtime = os.path.getmtime(self.genres_vocab_path)
if mtime > self.genres_vocab_mtime:
self._load_genres_vocab()
except Exception:
pass # Ignore errors during hot reload check
def _get_genres_trie_node(self, prefix: str) -> Optional[Dict]:
"""
Get the trie node for a given prefix.
Returns None if the prefix is not valid (no genres start with this prefix).
"""
node = self.genres_trie
for char in prefix.lower():
if char not in node:
return None
node = node[char]
return node
def _is_complete_genre(self, text: str) -> bool:
"""Check if the given text is a complete genre in the vocabulary."""
node = self._get_genres_trie_node(text.strip())
return node is not None and node.get('_end', False)
def _get_trie_node_from_trie(self, trie: Dict, prefix: str) -> Optional[Dict]:
"""Get a trie node from a specific trie (helper for caption vs full trie)."""
node = trie
for char in prefix.lower():
if char not in node:
return None
node = node[char]
return node
def _get_allowed_genres_tokens(self) -> List[int]:
"""
Get allowed tokens for genres field based on trie matching.
The entire genres string (including commas) must match a complete entry in the vocab.
For example, if vocab contains "pop, rock, jazz", the generated string must exactly
match that entry - we don't treat commas as separators for individual genres.
Strategy:
1. If caption-matched genres exist, use that smaller trie first (faster + more relevant)
2. If no caption matches or prefix not in caption trie, fallback to full vocab trie
3. Get valid next characters from current trie node
4. For each candidate token, verify the full decoded text forms a valid trie prefix
"""
if not self.genres_vocab:
# No vocab loaded, allow all except newline if empty
return []
# Use the full accumulated value (don't split by comma - treat as single entry)
accumulated = self.accumulated_value.lower()
current_genre_prefix = accumulated.strip()
# Determine which trie to use: caption-matched (priority) or full vocab (fallback)
use_caption_trie = False
current_node = None
# Try caption-matched trie first if available
if self.caption_genres_trie:
if current_genre_prefix == "":
current_node = self.caption_genres_trie
use_caption_trie = True
else:
current_node = self._get_trie_node_from_trie(self.caption_genres_trie, current_genre_prefix)
if current_node is not None:
use_caption_trie = True
# Fallback to full vocab trie
if current_node is None:
if current_genre_prefix == "":
current_node = self.genres_trie
else:
current_node = self._get_genres_trie_node(current_genre_prefix)
if current_node is None:
# Invalid prefix, force newline to end
if self.newline_token:
return [self.newline_token]
return []
# Get valid next characters from trie node
valid_next_chars = set(k for k in current_node.keys() if k not in ('_end', '_tokens'))
# If current value is a complete genre, allow newline to end
is_complete = current_node.get('_end', False)
if not valid_next_chars:
# No more characters to match, only allow newline if complete
allowed = set()
if is_complete and self.newline_token:
allowed.add(self.newline_token)
return list(allowed)
# Collect candidate tokens based on first character
candidate_tokens = set()
for char in valid_next_chars:
if char in self._char_to_tokens:
candidate_tokens.update(self._char_to_tokens[char])
# Select the appropriate trie for validation
active_trie = self.caption_genres_trie if use_caption_trie else self.genres_trie
# Validate each candidate token: check if prefix + decoded_token is a valid trie prefix
allowed = set()
for token_id in candidate_tokens:
# Use precomputed decoded text (already normalized)
decoded_normalized = self._token_to_text.get(token_id, "")
if not decoded_normalized or not decoded_normalized.strip():
# Token decodes to empty or only whitespace - allow if space/comma is a valid next char
if ' ' in valid_next_chars or ',' in valid_next_chars:
allowed.add(token_id)
continue
# Build new prefix by appending decoded token
# Handle space-prefixed tokens (e.g., " rock" from "pop rock")
if decoded_normalized.startswith(' ') or decoded_normalized.startswith(','):
# Token has leading space/comma - append directly
new_prefix = current_genre_prefix + decoded_normalized
else:
new_prefix = current_genre_prefix + decoded_normalized
# Check if new_prefix is a valid prefix in the active trie
new_node = self._get_trie_node_from_trie(active_trie, new_prefix)
if new_node is not None:
allowed.add(token_id)
# If current value is a complete genre, also allow newline
if is_complete and self.newline_token:
allowed.add(self.newline_token)
return list(allowed)
def reset(self):
"""Reset the processor state for a new generation."""
self.state = FSMState.THINK_TAG
self.position_in_state = 0
self.accumulated_value = "" # Legacy, kept for compatibility
self.accumulated_token_ids = [] # Reset token ID sequence
self.codes_count = 0 # Reset codes counter
self.user_field_token_queue = [] # Reset user field token queue
self.current_user_field = None # Reset current user field
self.caption_after_newline = False # Reset caption newline tracking
self.caption_token_count = 0 # Reset caption token count
self.caption_ending = False # Reset caption ending tracking
self.pending_field_name = "" # Reset pending field name
def set_target_duration(self, duration: Optional[float]):
"""
Set the target duration for codes generation.
Args:
duration: Target duration in seconds. If None, no duration constraint is applied.
5 codes = 1 second, so target_codes = duration * 5.
"""
self.target_duration = duration
if duration is not None and duration > 0:
self.target_codes = int(duration * 5)
if self.debug:
logger.debug(f"Set target duration: {duration}s -> {self.target_codes} codes")
else:
self.target_codes = None
if self.debug:
logger.debug("Target duration cleared, no duration constraint")
    def _get_allowed_tokens_for_fixed_string(self, fixed_str: str) -> List[int]:
        """
        Get the token IDs that can continue the fixed string from current position.
        Returns list of allowed token IDs.
        Strategy: Find the longest prefix that encodes to a single token, and return that token.
        This ensures we generate by tokens, not character-by-character.
        """
        # Portion of the fixed string not yet emitted at this position.
        remaining = fixed_str[self.position_in_state:]
        if not remaining:
            return []
        if self.debug:
            logger.debug(f"_get_allowed_tokens_for_fixed_string: fixed_str={repr(fixed_str)}, position_in_state={self.position_in_state}, remaining={repr(remaining)}")
        # Try encoding progressively longer prefixes, from longest to shortest
        # We want to find the longest prefix that encodes to a single token
        best_token = None
        # NOTE(review): best_prefix_len is recorded but never read afterwards.
        best_prefix_len = 0
        # First pass: find the longest prefix that encodes to exactly one token
        for end in range(len(remaining), 0, -1):  # Start from longest prefix
            prefix = remaining[:end]
            tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
            if tokens and len(tokens) == 1:
                # Found a prefix that encodes to a single token
                # Use this one (longest match)
                best_token = tokens[0]
                best_prefix_len = end
                if self.debug:
                    logger.debug(f"Found single-token match: prefix={repr(prefix)}, token_id={best_token}, token_text={repr(self.tokenizer.decode([best_token]))}")
                break
        # If we found a single-token match, return it (this is the preferred case)
        if best_token is not None:
            return [best_token]
        # Fallback: if no single-token match found, collect all possible first tokens
        # This handles edge cases where the string might need multiple tokens
        # But we still want to prefer longer matches
        # IMPORTANT: Only consider tokens that actually match the beginning of remaining string
        # Decode each candidate token and verify it matches the prefix
        allowed_tokens = {}
        for end in range(1, min(len(remaining) + 1, 20)):  # Limit search to avoid too many iterations
            prefix = remaining[:end]
            tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
            if tokens:
                first_token = tokens[0]
                # Verify: decode the token and check it matches the prefix start
                decoded_token = self.tokenizer.decode([first_token])
                # Normalize both for comparison (strip and lower)
                normalized_prefix = prefix.lstrip().lower()
                normalized_decoded = decoded_token.lstrip().lower()
                # Check if decoded token matches the prefix start (allowing for space prefixes)
                if normalized_decoded.startswith(normalized_prefix) or normalized_prefix.startswith(normalized_decoded):
                    # Store the longest prefix length for each token
                    if first_token not in allowed_tokens or end > allowed_tokens[first_token]:
                        allowed_tokens[first_token] = end
        # Return tokens sorted by prefix length (longest first)
        # This ensures we prefer longer matches
        sorted_tokens = sorted(allowed_tokens.items(), key=lambda x: x[1], reverse=True)
        result = [token for token, _ in sorted_tokens] if sorted_tokens else []
        if self.debug:
            logger.debug(f"Fallback: returning {len(result)} tokens: {[(t, repr(self.tokenizer.decode([t]))) for t in result[:5]]}")
            if result:
                logger.debug(f"Fixed string: {repr(fixed_str)}, position: {self.position_in_state}, remaining: {repr(remaining)}")
        return result
def _get_allowed_digit_tokens(self, min_val: int, max_val: int) -> List[int]:
"""
Get allowed digit tokens based on accumulated value and range constraints.
Uses early-blocking to prevent out-of-range values.
"""
if not self.accumulated_value:
# First digit: determine valid starting digits
allowed_digits = set()
for v in range(min_val, max_val + 1):
allowed_digits.add(int(str(v)[0]))
return [self.digit_tokens[d] for d in allowed_digits if d in self.digit_tokens]
current = int(self.accumulated_value)
allowed = []
for d in range(10):
new_value = int(self.accumulated_value + str(d))
# Check if this digit could lead to a valid final value
# A digit is valid if:
# 1. new_value <= max_val (not already exceeded)
# 2. new_value could potentially reach >= min_val
# (i.e., new_value * 10^k >= min_val for some k >= 0)
if new_value > max_val:
continue # Already exceeded max
# Check if we can still reach min_val
# If new_value is already >= min_val, it's valid
# If new_value < min_val, we need more digits, but new_value * 10 must not exceed max
if new_value >= min_val:
allowed.append(d)
elif new_value * 10 <= max_val:
# Can add more digits
allowed.append(d)
return [self.digit_tokens[d] for d in allowed if d in self.digit_tokens]
def _get_allowed_numeric_tokens(self, prefix_tree: Dict[Tuple[int, ...], Set[int]]) -> List[int]:
"""
Get allowed tokens for numeric field using the precomputed prefix tree.
IMPORTANT: Uses token ID sequence as key (not string) to avoid tokenization mismatches.
Args:
prefix_tree: Precomputed prefix tree mapping token ID sequence -> set of allowed token IDs
Returns:
List of allowed token IDs for current accumulated_token_ids
"""
token_prefix = tuple(self.accumulated_token_ids)
if token_prefix in prefix_tree:
return list(prefix_tree[token_prefix])
# No valid continuation found - return empty list
# The caller will handle this by forcing newline to end the field
return []
def _should_end_numeric_field(self, logits: torch.Tensor, min_val: int, max_val: int) -> bool:
"""
Determine if we should end the current numeric field.
Returns True if P(newline) > P(any valid digit) AND current value is valid.
"""
if not self.accumulated_value:
return False
current = int(self.accumulated_value)
if current < min_val or current > max_val:
return False # Can't end yet, value not in range
# Get probabilities
probs = torch.softmax(logits, dim=-1)
newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0
# Get max probability among valid digit tokens
allowed_digits = self._get_allowed_digit_tokens(min_val, max_val)
if not allowed_digits:
return True # No more digits possible, must end
max_digit_prob = max(probs[0, t].item() for t in allowed_digits)
if self.debug:
logger.debug(f"Numeric field decision: newline_prob={newline_prob:.4f}, max_digit_prob={max_digit_prob:.4f}")
return newline_prob > max_digit_prob
def _should_end_text_field(self, logits: torch.Tensor) -> bool:
"""
Determine if we should end a text field (genres).
Returns True if P(newline) > P(any other token) AND we have some content.
"""
if not self.accumulated_value.strip():
return False # Need at least some content
probs = torch.softmax(logits, dim=-1)
newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0
# Get max probability among non-newline tokens
masked_probs = probs.clone()
if self.newline_token:
masked_probs[0, self.newline_token] = 0
max_other_prob = masked_probs[0].max().item()
return newline_prob > max_other_prob
def _get_allowed_keyscale_tokens(self) -> List[int]:
"""
Get allowed tokens for keyscale field using the precomputed prefix tree.
Uses token ID sequence as key (not string) to avoid tokenization mismatches.
"""
# Use token ID sequence as key
token_prefix = tuple(self.accumulated_token_ids)
if token_prefix in self.keyscale_prefix_tree:
return list(self.keyscale_prefix_tree[token_prefix])
# Fallback: if we somehow drifted off (shouldn't happen with constrained decoding),
# return empty to force newline logic or stop.
return []
def _is_keyscale_complete(self) -> bool:
"""
Check if keyscale value is complete and valid.
Uses token ID sequence to check if current prefix allows newline.
"""
token_prefix = tuple(self.accumulated_token_ids)
# If current token sequence prefix is in tree and allows newline, it's complete
if token_prefix in self.keyscale_prefix_tree:
return self.newline_token in self.keyscale_prefix_tree[token_prefix]
return False
def _get_allowed_language_tokens(self) -> List[int]:
"""
Get allowed tokens for language field using the precomputed prefix tree.
Uses token ID sequence as key (not string) to avoid tokenization mismatches.
Similar to keyscale.
"""
token_prefix = tuple(self.accumulated_token_ids)
if token_prefix in self.language_prefix_tree:
return list(self.language_prefix_tree[token_prefix])
# Fallback: no valid continuation found
return []
def _get_allowed_timesig_tokens(self) -> List[int]:
"""
Get allowed tokens for timesignature field using the precomputed prefix tree.
Uses token ID sequence as key (not string) to avoid tokenization mismatches.
"""
token_prefix = tuple(self.accumulated_token_ids)
if token_prefix in self.timesig_prefix_tree:
return list(self.timesig_prefix_tree[token_prefix])
# No valid continuation found - return empty list
# The caller will handle this by forcing newline to end the field
return []
    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Apply constrained decoding by modifying logits.
        Args:
            input_ids: [batch_size, seq_len] input token IDs
            scores: [batch_size, vocab_size] logits for next token
        Returns:
            Modified scores with invalid tokens masked to -inf and temperature scaling applied
        """
        # Constraints disabled: only temperature scaling is applied.
        if not self.enabled:
            return self._apply_temperature_scaling(scores)
        if self.state == FSMState.COMPLETED:
            # In understanding phase, block audio codes during lyrics generation (COMPLETED state)
            if self.generation_phase == "understand" and self.audio_code_mask is not None:
                # Move mask to same device/dtype as scores if needed
                if self.audio_code_mask.device != scores.device or self.audio_code_mask.dtype != scores.dtype:
                    self.audio_code_mask = self.audio_code_mask.to(device=scores.device, dtype=scores.dtype)
                # Additive mask: -inf on audio-code positions, 0 elsewhere (presumably; built elsewhere).
                scores = scores + self.audio_code_mask
            return self._apply_temperature_scaling(scores)
        # For codes phase, detect if input already contains the closing think tag and skip to CODES_GENERATION
        if self.generation_phase == "codes" and self.state == FSMState.THINK_TAG:
            # Check if input contains the closing-tag token sequence
            if self._input_contains_think_end_tag(input_ids):
                # Skip metadata generation, go directly to codes generation
                self.state = FSMState.CODES_GENERATION
                self.codes_count = 0
                if self.debug:
                    logger.debug("Codes phase: detected in input, skipping to CODES_GENERATION")
        if self.state == FSMState.CODES_GENERATION:
            # Block all non-audio-code tokens (only allow audio codes and EOS)
            # Note: audio_code_token_ids already contains only valid tokens (0-63999 range)
            # because _precompute_audio_code_tokens() filters out invalid tokens during initialization
            if self.non_audio_code_mask is not None:
                # Move mask to same device/dtype as scores if needed
                if self.non_audio_code_mask.device != scores.device or self.non_audio_code_mask.dtype != scores.dtype:
                    self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
                scores = scores + self.non_audio_code_mask
            # Apply duration constraint in codes generation phase
            if self.target_codes is not None and self.eos_token_id is not None:
                if self.codes_count < self.target_codes:
                    # Block EOS token until target codes count is reached
                    scores[:, self.eos_token_id] = float('-inf')
                    if self.debug:
                        logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
                else:
                    # Force EOS token when target codes count is reached - inplace
                    eos_scores = scores[:, self.eos_token_id].clone()
                    scores.fill_(float('-inf'))
                    scores[:, self.eos_token_id] = eos_scores
                    if self.debug:
                        logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
            return self._apply_temperature_scaling(scores)
        batch_size = scores.shape[0]
        # Process each sequence in batch
        for b in range(batch_size):
            result = self._process_single_sequence(input_ids[b], scores[b:b+1])
            scores[b] = result[0]  # result is [1, vocab_size], need [vocab_size]
        # Apply temperature scaling after constraint masking
        return self._apply_temperature_scaling(scores)
def _input_contains_think_end_tag(self, input_ids: torch.LongTensor) -> bool:
"""
Check if input contains the closing tag.
Args:
input_ids: [batch_size, seq_len] input token IDs
Returns:
True if is found in the input (any sequence in batch)
"""
# Tokenize to get its token sequence
think_end_tokens = self.tokenizer.encode("", add_special_tokens=False)
if not think_end_tokens:
return False
# Check each sequence in batch
for b in range(input_ids.shape[0]):
seq = input_ids[b].tolist()
# Search for the token sequence in the input
for i in range(len(seq) - len(think_end_tokens) + 1):
if seq[i:i+len(think_end_tokens)] == think_end_tokens:
return True
return False
def _apply_temperature_scaling(self, scores: torch.FloatTensor) -> torch.FloatTensor:
"""
Apply temperature scaling based on current generation phase.
Temperature scaling: logits = logits / temperature
- Lower temperature (< 1.0) makes distribution sharper (more deterministic)
- Higher temperature (> 1.0) makes distribution flatter (more diverse)
Args:
scores: [batch_size, vocab_size] logits
Returns:
Temperature-scaled logits
"""
# Determine which temperature to use based on current state
if self.state == FSMState.CODES_GENERATION or self.state == FSMState.COMPLETED:
temperature = self.codes_temperature
else:
temperature = self.metadata_temperature
# If no temperature is set for this phase, return scores unchanged
if temperature is None:
return scores
# Avoid division by zero
if temperature <= 0:
temperature = 1e-6
# Apply temperature scaling
return scores / temperature
def _get_user_provided_field_tokens(self, field_name: str) -> Optional[List[int]]:
"""
Get token sequence for a user-provided field (field_name + value + newline).
Uses the same tokenization logic as prefix tree building.
Args:
field_name: Field name ("bpm", "caption", "duration", "keyscale", "language", "timesignature")
Returns:
List of token IDs for the complete field, or None if field is not provided
"""
value = self.user_provided_metadata.get(field_name)
if value is None:
return None
# Build full field string with space (matching prefix tree tokenization)
field_to_prefix = {
"bpm": "bpm: ",
"caption": "caption: ",
"duration": "duration: ",
"keyscale": "keyscale: ",
"language": "language: ",
"timesignature": "timesignature: ",
"genres": "genres: ",
}
prefix = field_to_prefix[field_name]
full_text = f"{prefix}{value}\n"
# Tokenize the full field
tokens = self.tokenizer.encode(full_text, add_special_tokens=False)
# Extract only the field tokens (skip the prefix tokens that match state machine output)
# The state machine generates "field_name:" (no space), so we need to match that
prefix_for_matching = field_name + ":"
prefix_tokens = self.tokenizer.encode(prefix_for_matching, add_special_tokens=False)
# Find where prefix ends in full tokens
if len(tokens) >= len(prefix_tokens) and tokens[:len(prefix_tokens)] == prefix_tokens:
# Return tokens after prefix (field value + newline)
return tokens[len(prefix_tokens):]
else:
# Fallback: return all tokens (shouldn't happen if tokenization is consistent)
if self.debug:
logger.warning(f"Could not match prefix tokens for field {field_name}, using all tokens")
return tokens
def _process_single_sequence(
    self,
    input_ids: torch.LongTensor,
    scores: torch.FloatTensor,
) -> torch.FloatTensor:
    """Process a single sequence and return modified scores (inplace when possible).

    Dispatches on the current FSM state and constrains the next-token logits:
      - fixed-string states whitelist only tokens that continue the literal string;
      - *_VALUE states whitelist tokens via prefix trees / vocabularies;
      - user-provided fields bypass sampling entirely by injecting queued tokens.

    Args:
        input_ids: Token ids generated so far (only forwarded on recursion).
        scores: Next-token logits; assumed shape (1, vocab_size) given the
            ``scores[0, ...]`` indexing below — TODO confirm against caller.

    Returns:
        The scores tensor with constraints applied. Usually modified in place,
        but the caption branch may re-bind it (``scores + audio_code_mask``),
        so callers must use the returned tensor.
    """
    # Check if we have tokens in queue for user-provided field
    # If so, inject the next token directly
    if self.user_field_token_queue:
        next_token = self.user_field_token_queue[0]
        self._apply_whitelist_inplace(scores, [next_token])
        return scores
    if self.state in self.fixed_strings:
        # Fixed string state: force specific tokens
        fixed_str = self.fixed_strings[self.state]
        allowed = self._get_allowed_tokens_for_fixed_string(fixed_str)
        if allowed:
            # Check if we should stop at reasoning (after tag)
            # This happens when we're about to complete the tag
            if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                # Check if the next token would complete the fixed string
                remaining_chars = len(fixed_str) - self.position_in_state
                # If remaining is small (<= 10 chars, which is typically 1-2 tokens), force EOS
                if remaining_chars <= 10:
                    # Force EOS token to stop generation
                    if self.eos_token_id is not None:
                        self._apply_whitelist_inplace(scores, [self.eos_token_id])
                        if self.debug:
                            logger.debug(f"stop_at_reasoning=True: forcing EOS near end of tag (remaining: {remaining_chars} chars)")
                        return scores
            # Apply whitelist constraint inplace
            self._apply_whitelist_inplace(scores, allowed)
        else:
            # Position exceeds string, move to next state
            # If stop_at_reasoning is True and we're transitioning from THINK_END_TAG,
            # force EOS before transitioning
            if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                # Force EOS token to stop generation
                if self.eos_token_id is not None:
                    self._apply_whitelist_inplace(scores, [self.eos_token_id])
                    if self.debug:
                        logger.debug(f"stop_at_reasoning=True: forcing EOS after completing tag")
                    return scores
            old_state = self.state
            self._transition_to_next_state()
            # Avoid infinite recursion: if we're still in a fixed_strings state, just return scores
            if self.state in self.fixed_strings:
                # This shouldn't happen, but if it does, just return scores to avoid recursion
                if self.debug:
                    logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion")
                return scores
            # For recursion, reset scores to zero (no constraints from previous state)
            scores.zero_()
            return self._process_single_sequence(input_ids, scores)
    elif self.state == FSMState.BPM_VALUE:
        # Check if field is user-provided and we haven't started injecting yet
        if self.user_provided_metadata["bpm"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
            # Initialize token queue with field value tokens (value + newline)
            value = self.user_provided_metadata["bpm"]
            # Tokenize " value\n" (space + value + newline) to match actual tokenization
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "bpm"
                # Inject first token
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores
        # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "120")
        allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree)
        # Also allow newline if current token sequence prefix allows it
        token_prefix = tuple(self.accumulated_token_ids)
        if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]:
            allowed = allowed + [self.newline_token]
        self._apply_whitelist_inplace(scores, allowed)
    elif self.state == FSMState.CAPTION_VALUE:
        # Caption field generation with YAML format support:
        # - Allow newlines and spaces (YAML multi-line formatting)
        # - Block audio codes and backticks
        # - Max 512 tokens
        # - Transition when model wants to generate next field (non-indented line)
        # Check if field is user-provided and we haven't started injecting yet
        if self.user_provided_metadata["caption"] is not None and not self.user_field_token_queue and not self.accumulated_value:
            # Initialize token queue with field value tokens (value + newline)
            value = self.user_provided_metadata["caption"]
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "caption"
                # Inject first token
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores
        # Check if we should transition after a newline (non-indented line = new field)
        if self.caption_after_newline:
            # Get top token from current scores
            top_token_id = torch.argmax(scores[0]).item()
            top_token_text = self.tokenizer.decode([top_token_id])
            # If top token does NOT start with space/tab, it's a new field (like "duration:")
            if len(top_token_text) > 0 and top_token_text[0] not in ' \t':
                # Caption is ending - LM is generating next field name
                # Instead of forcing state transition to DURATION_NAME (which would regenerate the field name),
                # we enter a "caption_ending" mode where we allow free generation until we detect the field value
                self.caption_after_newline = False
                self.caption_ending = True  # Start tracking field name
                self.pending_field_name = ""  # Reset pending field name
                # Allow free generation (no constraints) so LM can generate field name naturally
                return scores
            else:
                # It's indentation, continue caption (don't transition!)
                self.caption_after_newline = False
                # Continue normal caption generation
                # Fall through to caption constraints below
        # If caption is ending (LM generating next field name), allow free generation
        # and track the field name until we see colon
        if self.caption_ending:
            # Allow any token (free generation)
            # The field name detection will happen in update_state()
            return scores
        # Block backticks (code blocks) - inplace
        if self.backtick_token is not None:
            scores[0, self.backtick_token] = float('-inf')
        # Block ALL audio code tokens (critical - these should never appear in caption)
        # Use precomputed mask for O(1) performance instead of O(n) loop
        if self.audio_code_mask is not None:
            # Move mask to same device/dtype as scores if needed
            if self.audio_code_mask.device != scores.device or self.audio_code_mask.dtype != scores.dtype:
                self.audio_code_mask = self.audio_code_mask.to(device=scores.device, dtype=scores.dtype)
            # NOTE: this re-binds `scores` (out-of-place add) — caller must use the return value
            scores = scores + self.audio_code_mask
        # Enforce 512 token limit for caption
        if self.caption_token_count >= 512:
            # Force end by only allowing newline
            if self.newline_token is not None:
                self._apply_whitelist_inplace(scores, [self.newline_token])
            return scores
        # Allow natural generation (with blocked audio codes and backticks)
        return scores
    elif self.state == FSMState.DURATION_VALUE:
        # Check if field is user-provided and we haven't started injecting yet
        if self.user_provided_metadata["duration"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
            # Initialize token queue with field value tokens (value + newline)
            value = self.user_provided_metadata["duration"]
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "duration"
                # Inject first token
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores
        # If target_duration is set, force generate that exact value
        if self.target_duration is not None:
            target_str = str(int(self.target_duration))
            current_pos = len(self.accumulated_value)
            if current_pos < len(target_str):
                # Force the next digit
                next_digit = int(target_str[current_pos])
                if next_digit in self.digit_tokens:
                    self._apply_whitelist_inplace(scores, [self.digit_tokens[next_digit]])
            else:
                # All digits generated, force newline
                if self.newline_token:
                    self._apply_whitelist_inplace(scores, [self.newline_token])
        else:
            # Normal duration generation with range constraint
            # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "60", "120")
            allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree)
            # Also allow newline if current token sequence prefix allows it
            token_prefix = tuple(self.accumulated_token_ids)
            if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]:
                allowed = allowed + [self.newline_token]
            self._apply_whitelist_inplace(scores, allowed)
    elif self.state == FSMState.GENRES_VALUE:
        # Check if field is user-provided and we haven't started injecting yet
        if self.user_provided_metadata["genres"] is not None and not self.user_field_token_queue and not self.accumulated_value:
            # Initialize token queue with field value tokens (value + newline)
            value = self.user_provided_metadata["genres"]
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "genres"
                # Inject first token
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores
        # Try to hot-reload genres vocab if file has changed
        self._try_reload_genres_vocab()
        # Get allowed tokens based on genres vocabulary
        allowed = self._get_allowed_genres_tokens()
        if allowed:
            # Use vocabulary-constrained decoding
            self._apply_whitelist_inplace(scores, allowed)
        elif self.genres_vocab:
            # Vocab is loaded but no valid continuation found
            # Force newline to end the field
            if self.newline_token:
                if self.debug:
                    logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
                self._apply_whitelist_inplace(scores, [self.newline_token])
        else:
            # Fallback: no vocab loaded, use probability-based ending
            if self._should_end_text_field(scores):
                if self.newline_token:
                    self._apply_whitelist_inplace(scores, [self.newline_token])
                self._transition_to_next_state()
            else:
                # Allow any token except newline if we don't have content yet
                if not self.accumulated_value.strip():
                    if self.newline_token:
                        scores[0, self.newline_token] = float('-inf')
                # Otherwise, don't constrain (fallback behavior)
    elif self.state == FSMState.KEYSCALE_VALUE:
        # Check if field is user-provided and we haven't started injecting yet
        if self.user_provided_metadata["keyscale"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
            # Initialize token queue with field value tokens (value + newline)
            value = self.user_provided_metadata["keyscale"]
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "keyscale"
                # Inject first token
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores
        # Check if current token sequence is complete (allows newline)
        token_prefix = tuple(self.accumulated_token_ids)
        if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]:
            # Complete keyscale, allow newline
            if self.newline_token:
                self._apply_whitelist_inplace(scores, [self.newline_token])
        else:
            # Not complete, allow valid continuation tokens
            allowed = self._get_allowed_keyscale_tokens()
            if allowed:
                self._apply_whitelist_inplace(scores, allowed)
            else:
                # No valid tokens found - force newline to end field
                # This handles edge cases where keyscale format is unexpected
                if self.newline_token:
                    self._apply_whitelist_inplace(scores, [self.newline_token])
    elif self.state == FSMState.LANGUAGE_VALUE:
        # Language field: Use top-1 probability language (greedy selection)
        # Unlike other fields, we don't use prefix tree sampling.
        # Instead, we select the highest probability language at the start,
        # then force generate the rest of that language code.
        # Check if field is user-provided and we haven't started injecting yet
        if self.user_provided_metadata["language"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
            # Initialize token queue with field value tokens (value + newline)
            value = self.user_provided_metadata["language"]
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "language"
                # Inject first token
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores
        # If we haven't started generating language yet (empty accumulated_token_ids),
        # select the top-1 probability token from all valid first tokens
        if not self.accumulated_token_ids:
            # Get all possible first tokens for all languages
            empty_prefix = tuple()
            if empty_prefix in self.language_prefix_tree:
                candidate_tokens = list(self.language_prefix_tree[empty_prefix])
                if candidate_tokens:
                    # Find the token with highest probability (top-1) among candidates
                    # Use tensor indexing to get scores of candidate tokens directly
                    candidate_indices = torch.tensor(candidate_tokens, device=scores.device, dtype=torch.long)
                    candidate_scores = scores[0, candidate_indices]
                    # Get the highest probability token among candidates
                    best_idx = torch.argmax(candidate_scores).item()
                    top_token_id = candidate_tokens[best_idx]
                    # Only allow this top-1 token, block all others
                    self._apply_whitelist_inplace(scores, [top_token_id])
                    if self.debug:
                        top_token_text = self.tokenizer.decode([top_token_id])
                        logger.debug(f"Language field: selected top-1 token {top_token_id} ({repr(top_token_text)}) from {len(candidate_tokens)} candidates")
                else:
                    # No valid first tokens found - force newline
                    if self.newline_token:
                        self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                # Empty prefix not in tree - force newline
                if self.newline_token:
                    self._apply_whitelist_inplace(scores, [self.newline_token])
        else:
            # We've started generating a language, continue with prefix tree constraints
            # Check if current token sequence is complete (allows newline)
            token_prefix = tuple(self.accumulated_token_ids)
            if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
                # Complete language, allow newline
                if self.newline_token:
                    self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                # Not complete, allow valid continuation tokens
                allowed = self._get_allowed_language_tokens()
                if allowed:
                    self._apply_whitelist_inplace(scores, allowed)
                else:
                    # No valid tokens found - force newline to end field
                    if self.newline_token:
                        self._apply_whitelist_inplace(scores, [self.newline_token])
    elif self.state == FSMState.TIMESIG_VALUE:
        # Check if field is user-provided and we haven't started injecting yet
        if self.user_provided_metadata["timesignature"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
            # Initialize token queue with field value tokens (value + newline)
            value = self.user_provided_metadata["timesignature"]
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "timesignature"
                # Inject first token
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores
        # Check if current token sequence is complete (allows newline)
        token_prefix = tuple(self.accumulated_token_ids)
        if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]:
            # Complete value, allow newline
            if self.newline_token:
                self._apply_whitelist_inplace(scores, [self.newline_token])
        else:
            # Not complete, allow valid continuation tokens
            allowed = self._get_allowed_timesig_tokens()
            self._apply_whitelist_inplace(scores, allowed)
    return scores
def _transition_to_next_state(self):
    """Advance the FSM to the current state's successor and reset per-field tracking.

    No-op when the current state has no entry in the successor table.
    """
    if self.state not in self.next_state:
        return
    previous = self.state
    successor = self.next_state[previous]
    # Handle different cases at THINK_END_TAG based on generation phase.
    # NOTE: do NOT override the successor here when stop_at_reasoning=True —
    # the closing tag must be generated first; update_state() handles
    # stop_at_reasoning AFTER the tag is complete.
    if previous == FSMState.THINK_END_TAG and self.generation_phase == "understand":
        # Understanding mode: free-form lyrics follow the metadata,
        # so no further constrained decoding is applied.
        successor = FSMState.COMPLETED
        if self.debug:
            logger.debug(f"generation_phase='understand': allowing free-form lyrics after ")
    # else: default successor (CODES_GENERATION for the "codes" phase) stands.
    self.state = successor
    self.position_in_state = 0
    self.accumulated_value = ""  # Legacy, kept for compatibility
    self.accumulated_token_ids = []  # Reset token ID sequence for new field
    self.caption_after_newline = False  # Reset caption newline tracking
    self.caption_token_count = 0  # Reset caption token count
    self.caption_ending = False  # Reset caption ending tracking
    self.pending_field_name = ""  # Reset pending field name
    if self.debug:
        logger.debug(f"FSM transition: {previous.name} -> {self.state.name}")
def update_state(self, generated_token_id: int):
"""
Update internal state after a token has been generated.
This should be called after each token generation.
Args:
generated_token_id: The token ID that was just generated
"""
if not self.enabled:
return
if self.state == FSMState.COMPLETED:
return
if self.state == FSMState.CODES_GENERATION:
# Count generated codes for duration constraint
self.codes_count += 1
if self.debug and self.target_codes is not None:
logger.debug(f"Codes count: {self.codes_count}/{self.target_codes}")
return
# Handle user-provided field token injection
if self.user_field_token_queue:
# Verify the generated token matches the expected token from queue
expected_token = self.user_field_token_queue[0]
if generated_token_id != expected_token:
if self.debug:
logger.warning(f"Expected token {expected_token} but got {generated_token_id} for user-provided field {self.current_user_field}")
# Remove consumed token from queue
self.user_field_token_queue.pop(0)
# If queue is empty, field injection is complete, transition to next state
if not self.user_field_token_queue:
if self.debug:
logger.debug(f"Completed injection of user-provided field: {self.current_user_field}")
field_name = self.current_user_field
self.current_user_field = None
# Transition to next state (skip VALUE state since we already injected everything)
# The next state should be determined by _get_next_field_state
next_state = self._get_next_field_state(field_name)
if next_state:
old_state = self.state
self.state = next_state
self.position_in_state = 0
self.accumulated_value = ""
self.accumulated_token_ids = []
if self.debug:
logger.debug(f"FSM transition (after user field injection): {old_state.name} -> {self.state.name}")
else:
# All fields done, go to THINK_END_TAG
self._transition_to_next_state()
return
token_str = self.tokenizer.decode([generated_token_id])
if self.debug:
logger.debug(f"Generated token: {repr(token_str)} (id={generated_token_id}), state={self.state.name}")
if self.state in self.fixed_strings:
# Update position in fixed string
fixed_str = self.fixed_strings[self.state]
self.position_in_state += len(token_str)
# Check if we've completed the fixed string
if self.position_in_state >= len(fixed_str):
# Special handling for THINK_END_TAG with stop_at_reasoning
if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
# tag is complete, now we can stop generation
# Force transition to COMPLETED instead of CODES_GENERATION
old_state = self.state
self.state = FSMState.COMPLETED
self.position_in_state = 0
self.accumulated_value = ""
self.accumulated_token_ids = []
if self.debug:
logger.debug(f"FSM transition (stop_at_reasoning): {old_state.name} -> {self.state.name}")
else:
self._transition_to_next_state()
elif self.state in [FSMState.BPM_VALUE, FSMState.DURATION_VALUE, FSMState.TIMESIG_VALUE]:
# Accumulate numeric value using token ID sequence
if generated_token_id == self.newline_token:
old_state = self.state
self._transition_to_next_state()
# IMPORTANT: After state transition, if new state is a fixed_strings state,
# we should NOT update position_in_state with the newline token length,
# because that token belongs to the old state, not the new state.
# Return early to avoid the fixed_strings update logic below.
if self.state in self.fixed_strings:
return
else:
# Add token ID to sequence (for prefix tree lookup)
self.accumulated_token_ids.append(generated_token_id)
# Also update legacy accumulated_value for compatibility
if token_str.strip().isdigit():
self.accumulated_value += token_str.strip()
elif self.state == FSMState.GENRES_VALUE:
if generated_token_id == self.newline_token:
# Newline ends the field
self._transition_to_next_state()
# IMPORTANT: After state transition, if new state is a fixed_strings state,
# we should NOT update position_in_state with the newline token length,
# because that token belongs to the old state, not the new state.
# Return early to avoid the fixed_strings update logic below.
if self.state in self.fixed_strings:
return
else:
# Genres still uses string-based trie, so keep accumulated_value
self.accumulated_value += token_str
elif self.state == FSMState.CAPTION_VALUE:
# Track token count for 512 limit
self.caption_token_count += 1
# Accumulate caption text
self.accumulated_value += token_str
# Track if this token contains a newline (for transition detection)
# Token may be '\n' alone or combined with other chars like '.\n'
if '\n' in token_str:
# Mark that we need to check next token for field transition
self.caption_after_newline = True
else:
# Not a newline - if we were after newline and this is not space,
# transition already happened in _process_single_sequence
self.caption_after_newline = False
# If caption is ending, accumulate field name and detect field completion
if self.caption_ending:
self.pending_field_name += token_str
# Check if we've completed a field name (detected colon)
if ':' in token_str or token_str.strip() == ':':
# Extract field name (before colon)
field_name_full = self.pending_field_name.strip()
# Remove trailing colon if present
field_name = field_name_full.rstrip(':').strip().lower()
if self.debug:
logger.debug(f"Detected field name after caption: {repr(field_name)}")
# Map field name to VALUE state
field_name_to_value_state = {
"duration": FSMState.DURATION_VALUE,
"genres": FSMState.GENRES_VALUE,
"keyscale": FSMState.KEYSCALE_VALUE,
"language": FSMState.LANGUAGE_VALUE,
"timesignature": FSMState.TIMESIG_VALUE,
}
if field_name in field_name_to_value_state:
# Transition directly to the field's VALUE state
old_state = self.state
self.state = field_name_to_value_state[field_name]
self.position_in_state = 0
self.accumulated_value = ""
self.accumulated_token_ids = []
self.caption_ending = False
self.pending_field_name = ""
if self.debug:
logger.debug(f"FSM transition (caption ending): {old_state.name} -> {self.state.name}")
else:
# Unknown field name, force transition to next field
if self.debug:
logger.warning(f"Unknown field name after caption: {repr(field_name)}, forcing transition")
self.caption_ending = False
self.pending_field_name = ""
self._transition_to_next_state()
elif self.state == FSMState.KEYSCALE_VALUE:
if generated_token_id == self.newline_token:
# Newline ends the field
self._transition_to_next_state()
# IMPORTANT: After state transition, if new state is a fixed_strings state,
# we should NOT update position_in_state with the newline token length,
# because that token belongs to the old state, not the new state.
# Return early to avoid the fixed_strings update logic below.
if self.state in self.fixed_strings:
return
else:
# Add token ID to sequence (for prefix tree lookup)
self.accumulated_token_ids.append(generated_token_id)
# Also update legacy accumulated_value for compatibility
self.accumulated_value += token_str
elif self.state == FSMState.LANGUAGE_VALUE:
if generated_token_id == self.newline_token:
# Newline ends the field
self._transition_to_next_state()
if self.state in self.fixed_strings:
return
else:
# Add token ID to sequence (for prefix tree lookup)
self.accumulated_token_ids.append(generated_token_id)
# Also update legacy accumulated_value for compatibility
self.accumulated_value += token_str