|
|
| from enum import Enum, auto
|
| from typing import Optional, Dict, Any, Tuple, List, Callable, Set
|
| from loguru import logger
|
| from transformers import AutoTokenizer
|
| from transformers.generation.logits_process import LogitsProcessor
|
| import os
|
| import torch
|
| from .phase1_constants import (
|
| VALID_LANGUAGES,
|
| KEYSCALE_NOTES,
|
| KEYSCALE_ACCIDENTALS,
|
| KEYSCALE_MODES,
|
| VALID_KEYSCALES,
|
| BPM_MIN,
|
| BPM_MAX,
|
| DURATION_MIN,
|
| DURATION_MAX,
|
| VALID_TIME_SIGNATURES,
|
| )
|
|
|
|
|
|
|
|
|
|
|
class FSMState(Enum):
    """
    States of the metadata-generation finite state machine.

    Members are declared in the same order as the fields appear in the
    generated ``<think>`` block. Values are the consecutive integers 1..24,
    i.e. exactly what ``auto()`` would assign.
    """

    # Opening tag
    THINK_TAG = 1
    NEWLINE_AFTER_THINK = 2

    # bpm
    BPM_NAME = 3
    BPM_VALUE = 4
    NEWLINE_AFTER_BPM = 5

    # caption
    CAPTION_NAME = 6
    CAPTION_VALUE = 7

    # duration
    DURATION_NAME = 8
    DURATION_VALUE = 9
    NEWLINE_AFTER_DURATION = 10

    # genres
    GENRES_NAME = 11
    GENRES_VALUE = 12
    NEWLINE_AFTER_GENRES = 13

    # keyscale
    KEYSCALE_NAME = 14
    KEYSCALE_VALUE = 15
    NEWLINE_AFTER_KEYSCALE = 16

    # language
    LANGUAGE_NAME = 17
    LANGUAGE_VALUE = 18

    # timesignature
    TIMESIG_NAME = 19
    TIMESIG_VALUE = 20
    NEWLINE_AFTER_TIMESIG = 21

    # Closing tag and what follows it
    THINK_END_TAG = 22
    CODES_GENERATION = 23
    COMPLETED = 24
|
|
|
|
|
| class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
| """
|
| FSM-driven LogitsProcessor that constrains generation to produce valid metadata.
|
|
|
| This processor enforces the following format:
|
| <think>
|
| bpm: [30-300]
|
| caption: [text without code blocks, ends with period + newline]
|
| duration: [10-600]
|
    keyscale: [A-G][#/♭]? [major/minor]
|
| language: [en/zh/ja/ko/es/fr/de/uk/ru/...]
|
| timesignature: [2/3/4/6]
|
| </think>
|
|
|
| It uses token masking (setting invalid token logits to -inf) to enforce constraints.
|
| For numeric fields, it uses early-blocking to prevent out-of-range values.
|
| For field transitions (e.g., end of numeric value), it compares P(newline) vs P(digit).
|
| For caption field, it blocks code blocks and newlines, and only transitions when
|
| the previous token was a period and newline has the highest probability.
|
| """
|
|
|
def __init__(
    self,
    tokenizer: AutoTokenizer,
    enabled: bool = True,
    debug: bool = False,
    genres_vocab_path: Optional[str] = None,
    skip_genres: bool = True,
    max_duration: Optional[int] = None,
):
    """
    Set up the constrained logits processor.

    Create it once when the LLM is loaded and reuse it for every generation:
    the vocabulary scans, token masks and prefix trees built here are the
    expensive part.

    Args:
        tokenizer: Tokenizer used for every encode/decode the processor performs.
        enabled: Whether constrained decoding is switched on.
        debug: Whether to log debug information.
        genres_vocab_path: Genres vocabulary file. When falsy, ``genres_vocab.txt``
            located next to this module is used.
        skip_genres: Whether the genres field is left out of the FSM.
        max_duration: Upper bound for the duration field in seconds; ``None``
            means ``DURATION_MAX``.
    """
    # ---- Static configuration ----
    self.tokenizer = tokenizer
    self.enabled = enabled
    self.debug = debug
    self.skip_genres = skip_genres
    self.max_duration = DURATION_MAX if max_duration is None else max_duration
    self.skip_caption = False
    self.skip_language = False
    self.caption: Optional[str] = None

    # Nothing is user-provided yet: None means "let the model generate it".
    self.user_provided_metadata: Dict[str, Optional[str]] = dict.fromkeys(
        ("bpm", "caption", "duration", "keyscale", "language", "timesignature", "genres")
    )

    # Sampling temperatures (None = not set).
    self.metadata_temperature: Optional[float] = None
    self.codes_temperature: Optional[float] = None

    # Audio-code length bookkeeping.
    self.target_duration: Optional[float] = None
    self.target_codes: Optional[int] = None
    self.codes_count: int = 0

    self.stop_at_reasoning: bool = False
    self.generation_phase: str = "cot"

    # ---- Per-generation FSM state ----
    self.state = FSMState.THINK_TAG
    self.position_in_state = 0
    self.accumulated_value = ""
    self.accumulated_token_ids: List[int] = []

    self.caption_after_newline = False
    self.caption_token_count = 0
    self.caption_ending = False
    self.pending_field_name = ""

    self.user_field_token_queue: List[int] = []
    self.current_user_field: Optional[str] = None

    # Decode cache; it must exist before _precompute_tokens(), which fills it
    # while scanning the vocabulary.
    self._decoded_vocab_texts: Optional[List[str]] = None
    self._precompute_tokens()

    # ---- Genres vocabulary state (the file itself is loaded further down) ----
    self.genres_vocab_path = genres_vocab_path or os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "genres_vocab.txt"
    )
    self.genres_vocab: List[str] = []
    self.genres_vocab_mtime: float = 0.0
    self.genres_trie: Dict = {}
    self.caption_genres_trie: Dict = {}
    self.caption_matched_genres: List[str] = []

    self._char_to_tokens: Dict[str, set] = {}
    self._precompute_char_token_mapping()

    # ---- Allowed values for the numeric fields ----
    bpm_spec = {"min": BPM_MIN, "max": BPM_MAX}
    duration_spec = {"min": DURATION_MIN, "max": self.max_duration}
    timesig_spec = {"valid_values": VALID_TIME_SIGNATURES}
    self.field_specs = {
        "bpm": bpm_spec,
        "duration": duration_spec,
        "timesignature": timesig_spec,
    }
    self.valid_bpm_values = list(map(str, range(bpm_spec["min"], bpm_spec["max"] + 1)))
    self.valid_duration_values = list(
        map(str, range(duration_spec["min"], duration_spec["max"] + 1))
    )
    self.valid_timesig_values = list(map(str, timesig_spec["valid_values"]))

    # ---- Token-ID prefix trees (need newline_token / valid_* from above) ----
    self.keyscale_prefix_tree = self._build_keyscale_prefix_tree()
    (
        self.bpm_prefix_tree,
        self.duration_prefix_tree,
        self.timesig_prefix_tree,
    ) = [
        self._build_numeric_prefix_tree(
            values,
            context_prefix_for_matching=match_context,
            context_prefix_for_tokenization=token_context,
        )
        for values, match_context, token_context in (
            (self.valid_bpm_values, "bpm:", "bpm: "),
            (self.valid_duration_values, "duration:", "duration: "),
            (self.valid_timesig_values, "timesignature:", "timesignature: "),
        )
    ]
    self.language_prefix_tree = self._build_language_prefix_tree()

    self._load_genres_vocab()

    # Literal text emitted for each fixed (non-value) FSM state.
    self.fixed_strings = {
        FSMState.THINK_TAG: "<think>",
        FSMState.NEWLINE_AFTER_THINK: "\n",
        FSMState.BPM_NAME: "bpm:",
        FSMState.CAPTION_NAME: "caption:",
        FSMState.DURATION_NAME: "duration:",
        FSMState.GENRES_NAME: "genres:",
        FSMState.KEYSCALE_NAME: "keyscale:",
        FSMState.LANGUAGE_NAME: "language:",
        FSMState.TIMESIG_NAME: "timesignature:",
        FSMState.THINK_END_TAG: "</think>",
    }

    # Transition map depends on the skip_* flags set above.
    self._build_state_transitions()
|
|
|
def _get_next_field_state(self, current_field: str) -> Optional[FSMState]:
    """
    Return the NAME state of the field that follows ``current_field``.

    Fields switched off via ``skip_genres`` / ``skip_caption`` / ``skip_language``
    are jumped over. User-provided fields are not skipped here; their field
    names are still generated, so their NAME state is returned as usual.

    Args:
        current_field: One of "bpm", "caption", "duration", "genres", "keyscale",
            "language", "timesignature".

    Returns:
        The next field's NAME state, or ``FSMState.THINK_END_TAG`` when no field
        remains or ``current_field`` is not recognised.
    """
    ordered = (
        ("bpm", FSMState.BPM_NAME),
        ("caption", FSMState.CAPTION_NAME),
        ("duration", FSMState.DURATION_NAME),
        ("genres", FSMState.GENRES_NAME),
        ("keyscale", FSMState.KEYSCALE_NAME),
        ("language", FSMState.LANGUAGE_NAME),
        ("timesignature", FSMState.TIMESIG_NAME),
    )
    # Optional fields -> name of the flag that disables them (read lazily).
    skip_flag_for = {
        "genres": "skip_genres",
        "caption": "skip_caption",
        "language": "skip_language",
    }

    names = [name for name, _ in ordered]
    if current_field not in names:
        return FSMState.THINK_END_TAG

    for name, name_state in ordered[names.index(current_field) + 1:]:
        flag = skip_flag_for.get(name)
        if flag is not None and getattr(self, flag):
            continue
        return name_state

    return FSMState.THINK_END_TAG
|
|
|
def _build_state_transitions(self):
    """
    Recompute ``self.next_state`` from the current skip flags.

    The fixed transitions (opening tag, closing tag, codes phase) are always
    present. For every enabled field the map holds NAME -> VALUE and
    VALUE -> next field's NAME state (via ``_get_next_field_state``). Skipped
    optional fields (caption, genres, language) get no entries. timesignature
    is the last field, so its VALUE state leads straight to ``THINK_END_TAG``.
    """
    transitions = {
        FSMState.THINK_TAG: FSMState.NEWLINE_AFTER_THINK,
        FSMState.NEWLINE_AFTER_THINK: FSMState.BPM_NAME,
        FSMState.THINK_END_TAG: FSMState.CODES_GENERATION,
        FSMState.CODES_GENERATION: FSMState.COMPLETED,
    }
    self.next_state = transitions

    # (field, NAME state, VALUE state, disabled?) in generation order.
    field_table = (
        ("bpm", FSMState.BPM_NAME, FSMState.BPM_VALUE, False),
        ("caption", FSMState.CAPTION_NAME, FSMState.CAPTION_VALUE, self.skip_caption),
        ("duration", FSMState.DURATION_NAME, FSMState.DURATION_VALUE, False),
        ("genres", FSMState.GENRES_NAME, FSMState.GENRES_VALUE, self.skip_genres),
        ("keyscale", FSMState.KEYSCALE_NAME, FSMState.KEYSCALE_VALUE, False),
        ("language", FSMState.LANGUAGE_NAME, FSMState.LANGUAGE_VALUE, self.skip_language),
    )
    for field, name_state, value_state, disabled in field_table:
        if disabled:
            continue
        transitions[name_state] = value_state
        transitions[value_state] = self._get_next_field_state(field)

    transitions[FSMState.TIMESIG_NAME] = FSMState.TIMESIG_VALUE
    transitions[FSMState.TIMESIG_VALUE] = FSMState.THINK_END_TAG
|
|
|
def set_skip_genres(self, skip: bool):
    """Turn skipping of the genres field on/off and rebuild the transition map."""
    self.skip_genres = skip
    self._build_state_transitions()

def set_skip_caption(self, skip: bool):
    """Turn skipping of the caption field on/off and rebuild the transition map."""
    self.skip_caption = skip
    self._build_state_transitions()

def set_skip_language(self, skip: bool):
    """Turn skipping of the language field on/off and rebuild the transition map."""
    self.skip_language = skip
    self._build_state_transitions()
|
|
|
@staticmethod
def postprocess_caption(caption: str) -> str:
    """
    Flatten YAML-style multi-line caption text into a single line.

    Every line is stripped, blank lines are dropped, and the remaining lines
    are joined with single spaces. Falsy input ("" or None) is returned as-is.

    Example:
        Input:  "An emotional ballad.\\n  The track opens with piano.\\n  More text."
        Output: "An emotional ballad. The track opens with piano. More text."

    Args:
        caption: Raw caption text, possibly with YAML line breaks/indentation.

    Returns:
        The single-line caption.
    """
    if not caption:
        return caption
    stripped_lines = (segment.strip() for segment in caption.split('\n'))
    return ' '.join(piece for piece in stripped_lines if piece)
|
|
|
def set_stop_at_reasoning(self, stop: bool):
    """
    Choose whether generation should end right after the ``</think>`` tag.

    Args:
        stop: True -> stop as soon as ``</think>`` has been generated;
            False -> continue into the audio-codes generation phase.
    """
    self.stop_at_reasoning = stop
|
|
|
def set_generation_phase(self, phase: str):
    """
    Select the generation phase.

    Args:
        phase: "cot" for CoT metadata generation, "codes" for audio codes
            generation, or "understand" for audio understanding
            (codes -> metadata + lyrics). With "codes" and a prompt that already
            contains ``</think>``, metadata generation is skipped and the FSM goes
            straight to codes. With "understand", CoT metadata is generated and
            then free-form lyrics.

    Raises:
        ValueError: If ``phase`` is not one of the three accepted values; the
            current phase is left unchanged.
    """
    accepted_phases = ("cot", "codes", "understand")
    if phase in accepted_phases:
        self.generation_phase = phase
        return
    raise ValueError(f"Invalid generation phase: {phase!r}. Must be 'cot', 'codes', or 'understand'")
|
|
|
def set_user_metadata(self, metadata: Optional[Dict[str, Optional[str]]] = None):
    """
    Record which metadata fields the user supplied.

    Supplied (non-None) fields are used verbatim instead of being generated.
    Every known field missing from ``metadata`` is reset to None (= generate it).
    Keys other than the known fields are ignored. The transition map is rebuilt
    afterwards.

    Args:
        metadata: Mapping with any of these optional string fields:
            - "bpm": e.g. "120"
            - "caption": e.g. "A melodic piano piece..."
            - "duration": e.g. "234"
            - "keyscale": e.g. "G major"
            - "language": e.g. "en"
            - "timesignature": e.g. "4"
            - "genres": e.g. "Pop Rock"
            Passing None clears all user-provided metadata.
    """
    source = {} if metadata is None else metadata
    known_fields = ("bpm", "caption", "duration", "keyscale", "language", "timesignature", "genres")
    for name in known_fields:
        self.user_provided_metadata[name] = source[name] if name in source else None

    self._build_state_transitions()

    if not self.debug:
        return
    supplied = [key for key, value in self.user_provided_metadata.items() if value is not None]
    if supplied:
        logger.debug(f"User provided metadata fields: {supplied}")
    else:
        logger.debug("No user-provided metadata, all fields will be generated")
|
|
|
def _precompute_tokens(self):
    """
    Pre-compute commonly used token IDs, masks and lookup tables.

    For each probe string the *last* token ID of its encoding is kept; a probe
    that encodes to nothing is stored as None (scalar attributes) or omitted
    (dicts/lists). For non-ASCII probes such as the sharp/flat signs a
    byte-level tokenizer may need several tokens, in which case only the final
    piece is recorded.

    Ordering matters: ``vocab_size`` is set before
    ``_precompute_audio_code_tokens`` (which decodes the whole vocabulary), and
    ``eos_token_id`` before ``_build_audio_code_mask`` (which keeps EOS open in
    the inverse mask).
    """

    def last_token_id(text: str) -> Optional[int]:
        # Encode without special tokens and keep only the final token ID.
        ids = self.tokenizer.encode(text, add_special_tokens=False)
        return ids[-1] if ids else None

    # Digits 0-9 -> token ID.
    self.digit_tokens = {}
    for d in range(10):
        tok = last_token_id(str(d))
        if tok is not None:
            self.digit_tokens[d] = tok

    self.newline_token = last_token_id("\n")

    # Note names from KEYSCALE_NOTES -> token ID.
    self.note_tokens = {}
    for note in KEYSCALE_NOTES:
        tok = last_token_id(note)
        if tok is not None:
            self.note_tokens[note] = tok

    # Accidentals. The Unicode signs are written as escapes (U+266F MUSIC SHARP
    # SIGN, U+266D MUSIC FLAT SIGN) so a wrong source encoding cannot corrupt
    # them; the flat sign literal used here was mojibake ("â™"), so the
    # real flat token was never registered.
    self.sharp_tokens = [t for t in map(last_token_id, ["#", "\u266f"]) if t is not None]
    self.flat_tokens = [t for t in map(last_token_id, ["b", "\u266d"]) if t is not None]

    self.space_token = last_token_id(" ")

    # "major" and "minor" both begin with "m"/"M", so both modes share the same
    # candidate start tokens (in the order "m", "M").
    mode_start_tokens = [t for t in map(last_token_id, ["m", "M"]) if t is not None]
    self.major_start_tokens = list(mode_start_tokens)
    self.minor_start_tokens = list(mode_start_tokens)

    self.vocab_size = len(self.tokenizer)

    self.comma_token = last_token_id(",")
    self.eos_token_id = self.tokenizer.eos_token_id
    self.period_token = last_token_id(".")
    self.backtick_token = last_token_id("`")

    self.valid_languages = VALID_LANGUAGES

    # Audio-code token IDs (needs vocab_size) and the masks built from them
    # (needs eos_token_id).
    self.audio_code_token_ids: Set[int] = set()
    self._precompute_audio_code_tokens()

    self.audio_code_mask: Optional[torch.Tensor] = None
    self.non_audio_code_mask: Optional[torch.Tensor] = None
    self._build_audio_code_mask()

    self.valid_keyscales = VALID_KEYSCALES.copy()
|
|
|
|
|
|
|
|
|
def _precompute_audio_code_tokens(self):
    """
    Add to ``self.audio_code_token_ids`` every vocab ID whose decoded text is
    exactly an audio-code token (``<|audio_code_<digits>|>``).

    These IDs are meant to be blocked while the caption is generated. Existing
    entries of the set are kept (the set is extended, not replaced).
    """
    import re

    is_audio_code = re.compile(r'^<\|audio_code_\d+\|>$').match
    vocab_texts = self._get_decoded_vocab_texts()
    self.audio_code_token_ids.update(
        idx for idx, piece in enumerate(vocab_texts) if piece and is_audio_code(piece)
    )

    if self.debug:
        logger.debug(f"Found {len(self.audio_code_token_ids)} audio code tokens")
|
|
|
def _get_decoded_vocab_texts(self, chunk_size: int = 8192) -> List[str]:
    """
    Return the decoded text of every token ID in ``range(self.vocab_size)``.

    The result is computed once and cached in ``self._decoded_vocab_texts``.
    Decoding goes through ``tokenizer.batch_decode`` in chunks of
    ``chunk_size`` IDs. A chunk whose batch call raises is redone one token at
    a time with ``tokenizer.decode``; a token that still fails (or decodes to
    None) is recorded as "". A single warning summarises any fallbacks.
    """
    cached = self._decoded_vocab_texts
    if cached is not None:
        return cached

    total = self.vocab_size
    texts = [""] * total
    fallback_chunks = 0
    fallback_tokens = 0
    first_fallback_error = None

    def decode_single(token_id):
        # Per-token slow path; a failure becomes an empty string.
        try:
            piece = self.tokenizer.decode([token_id])
        except Exception:
            piece = ""
        return "" if piece is None else piece

    for lo in range(0, total, chunk_size):
        hi = min(lo + chunk_size, total)
        try:
            batch = self.tokenizer.batch_decode([[tid] for tid in range(lo, hi)])
            for offset, piece in enumerate(batch):
                texts[lo + offset] = "" if piece is None else piece
        except Exception as err:
            fallback_chunks += 1
            fallback_tokens += hi - lo
            if first_fallback_error is None:
                first_fallback_error = err
            for tid in range(lo, hi):
                texts[tid] = decode_single(tid)

    if fallback_chunks > 0:
        logger.warning(
            f"Tokenizer batch_decode fallback activated for {fallback_tokens} token(s) "
            f"across {fallback_chunks} chunk(s); using slower per-token decode. "
            f"First error: {first_fallback_error}"
        )

    self._decoded_vocab_texts = texts
    return texts
|
|
|
def _build_audio_code_mask(self):
    """
    Build two precomputed ``[1, vocab_size]`` float32 masks meant to be added to scores.

    - ``audio_code_mask``: 0 everywhere, ``-inf`` at audio-code token IDs
      (blocks audio codes).
    - ``non_audio_code_mask``: ``-inf`` everywhere, 0 at audio-code token IDs
      and at ``eos_token_id`` when it is set (for CODES_GENERATION: only
      audio codes or EOS may be produced).

    Both are None when no audio-code tokens were found.
    """
    if not self.audio_code_token_ids:
        self.audio_code_mask = None
        self.non_audio_code_mask = None
        return

    code_ids = list(self.audio_code_token_ids)
    shape = (1, self.vocab_size)
    neg_inf = float('-inf')

    # Additive mask that removes audio-code tokens.
    block_codes = torch.zeros(shape, dtype=torch.float32)
    block_codes[0, code_ids] = neg_inf
    self.audio_code_mask = block_codes

    # Inverse additive mask: only audio codes (and EOS) survive.
    only_codes = torch.full(shape, neg_inf, dtype=torch.float32)
    only_codes[0, code_ids] = 0
    if self.eos_token_id is not None:
        only_codes[0, self.eos_token_id] = 0
    self.non_audio_code_mask = only_codes

    if self.debug:
        logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens")
|
|
|
def _apply_whitelist_inplace(self, scores: torch.Tensor, allowed_tokens: List[int]) -> None:
    """
    Apply a whitelist in place: keep the scores of ``allowed_tokens``, set every other entry to -inf.

    Cheaper than building a mask tensor: no mask allocation and no tensor
    addition.

    The vocabulary is taken to be the last dimension, so this works for a
    single row ``[1, vocab_size]`` as well as a batch ``[batch, vocab_size]``.
    Every row is whitelisted identically; indexing only row 0 would have left
    the other rows entirely at -inf.

    Args:
        scores: Scores tensor whose last dimension is the vocabulary; modified in place.
        allowed_tokens: Token IDs to keep (any iterable of ints). Empty -> all -inf.
    """
    allowed = list(allowed_tokens)
    if not allowed:
        scores.fill_(float('-inf'))
        return

    allowed_indices = torch.tensor(allowed, device=scores.device, dtype=torch.long)
    # Save the surviving scores of every row before wiping the tensor.
    saved_values = scores[..., allowed_indices].clone()
    scores.fill_(float('-inf'))
    scores[..., allowed_indices] = saved_values
|
|
|
def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
    """
    Map token-ID prefixes of every valid keyscale to the token IDs allowed next.

    Keys are tuples of token IDs (not strings), so the tree mirrors the exact
    tokenization seen during generation. Each keyscale is encoded together with
    its context ("keyscale: " + value) because the tokenizer may merge the space
    after the colon into the value's first token. For example (IDs depend on
    the tokenizer):

        "keyscale: "        -> ['keys', 'cale', ':', ' ']
        "keyscale: G major" -> ['keys', 'cale', ':', ' G', ' major']

    The value tokens are whatever follows the tokens of "keyscale:" (no
    trailing space), e.g. [' G', ' major']. The complete value maps to the
    newline token (when one exists), which ends the field.

    A keyscale is skipped when its encoding does not start with the context
    tokens, yields no value tokens, or its first value token does not start
    with a note letter A-G. A whitespace-only first token is accepted.

    Returns:
        Dict mapping token-ID prefix -> set of allowed next token IDs.
    """
    prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}

    # The FSM emits "keyscale:" itself; values are tokenized after "keyscale: ".
    context_prefix_for_matching = "keyscale:"
    context_prefix_for_tokenization = "keyscale: "

    context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False)
    context_len = len(context_token_ids)

    if self.debug:
        context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids]
        logger.debug(f"Context for matching 'keyscale:' tokenizes to {context_token_ids} -> {context_tokens_str}")

    for keyscale in self.valid_keyscales:
        full_text = context_prefix_for_tokenization + keyscale
        full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)

        # The context must tokenize identically at the start of the full string.
        if full_token_ids[:context_len] != context_token_ids:
            if self.debug:
                logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
            continue

        keyscale_token_ids = full_token_ids[context_len:]
        if not keyscale_token_ids:
            if self.debug:
                logger.warning(f"No tokens extracted for keyscale '{keyscale}', skipping")
            continue

        first_token_id = keyscale_token_ids[0]
        first_token_str = self.tokenizer.decode([first_token_id])
        stripped = first_token_str.lstrip()
        first_char = stripped[0].upper() if stripped else ""
        # An empty first_char (whitespace-only first token, i.e. the tokenizer
        # kept the space separate) is deliberately accepted.
        if first_char and first_char not in "ABCDEFG":
            if self.debug:
                logger.debug(f"Skipping keyscale '{keyscale}': first token is '{first_token_str}' (id={first_token_id}), not a note")
            continue

        for i in range(len(keyscale_token_ids) + 1):
            allowed = prefix_to_tokens.setdefault(tuple(keyscale_token_ids[:i]), set())
            if i < len(keyscale_token_ids):
                allowed.add(keyscale_token_ids[i])
            elif self.newline_token is not None:
                # Token ID 0 is a valid ID, so test for None rather than truthiness.
                allowed.add(self.newline_token)

    if self.debug:
        logger.debug(f"Built keyscale prefix tree with {len(prefix_to_tokens)} token sequence prefixes")
        empty_prefix = tuple()
        if empty_prefix in prefix_to_tokens:
            first_tokens = prefix_to_tokens[empty_prefix]
            decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)]
            logger.debug(f"First tokens allowed (empty prefix): {decoded_first}")

    return prefix_to_tokens
|
|
|
def _build_numeric_prefix_tree(
    self,
    valid_values: List[str],
    context_prefix_for_matching: str = "",
    context_prefix_for_tokenization: str = ""
) -> Dict[Tuple[int, ...], Set[int]]:
    """
    Build a token-ID prefix tree for a numeric field, based on tokenization in context.

    Keys are tuples of token IDs (not strings) to avoid tokenization mismatches.
    Each value is encoded as ``context_prefix_for_tokenization + value``. Its
    value tokens are whatever follows the tokens of
    ``context_prefix_for_matching``. The complete value maps to the newline
    token (when one exists).

    Values are skipped when their encoding does not start with the context
    tokens, or when no value tokens remain (the value was absorbed into the
    context tokens). Registering such a value would put a newline under the
    empty prefix and allow an empty field.

    Args:
        valid_values: Valid numeric strings (e.g. ["30", "31", ..., "300"]).
        context_prefix_for_matching: Context the state machine emits (e.g. "bpm:"), no space.
        context_prefix_for_tokenization: Context used for tokenization (e.g. "bpm: "), with space.

    Returns:
        Dict mapping token-ID prefix -> set of allowed next token IDs.
    """
    prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}

    context_token_ids = (
        self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False)
        if context_prefix_for_matching
        else []
    )
    context_len = len(context_token_ids)

    for value_str in valid_values:
        full_text = context_prefix_for_tokenization + value_str
        token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)

        # The context must tokenize identically at the start of the full string.
        if token_ids[:context_len] != context_token_ids:
            if self.debug:
                logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
            continue

        value_token_ids = token_ids[context_len:]
        if not value_token_ids:
            if self.debug:
                logger.warning(f"No tokens extracted for value '{value_str}', skipping")
            continue

        for i in range(len(value_token_ids) + 1):
            allowed = prefix_to_tokens.setdefault(tuple(value_token_ids[:i]), set())
            if i < len(value_token_ids):
                allowed.add(value_token_ids[i])
            elif self.newline_token is not None:
                # Token ID 0 is a valid ID, so test for None rather than truthiness.
                allowed.add(self.newline_token)

    return prefix_to_tokens
|
|
|
def _build_language_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
    """
    Map token-ID prefixes of every valid language code to the token IDs allowed next.

    Works like the keyscale prefix tree: each code is encoded as
    "language: " + code. The value tokens are whatever follows the tokens of
    "language:" (no trailing space). The complete code maps to the newline
    token (when one exists).

    A code is skipped when its encoding does not start with the context tokens
    or yields no value tokens.

    Returns:
        Dict mapping token-ID prefix -> set of allowed next token IDs.
    """
    prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}

    context_prefix_for_matching = "language:"
    context_prefix_for_tokenization = "language: "

    context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False)
    context_len = len(context_token_ids)

    if self.debug:
        context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids]
        logger.debug(f"Context for matching 'language:' tokenizes to {context_token_ids} -> {context_tokens_str}")

    for lang in self.valid_languages:
        full_text = context_prefix_for_tokenization + lang
        full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)

        # The context must tokenize identically at the start of the full string.
        if full_token_ids[:context_len] != context_token_ids:
            if self.debug:
                logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
            continue

        lang_token_ids = full_token_ids[context_len:]
        if not lang_token_ids:
            if self.debug:
                logger.warning(f"No tokens extracted for language '{lang}', skipping")
            continue

        for i in range(len(lang_token_ids) + 1):
            allowed = prefix_to_tokens.setdefault(tuple(lang_token_ids[:i]), set())
            if i < len(lang_token_ids):
                allowed.add(lang_token_ids[i])
            elif self.newline_token is not None:
                # Token ID 0 is a valid ID, so test for None rather than truthiness.
                allowed.add(self.newline_token)

    if self.debug:
        logger.debug(f"Built language prefix tree with {len(prefix_to_tokens)} token sequence prefixes")
        empty_prefix = tuple()
        if empty_prefix in prefix_to_tokens:
            first_tokens = prefix_to_tokens[empty_prefix]
            decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)]
            logger.debug(f"First tokens allowed for language (empty prefix): {decoded_first}")

    return prefix_to_tokens
|
|
|
def diagnose_keyscale_prefix_tree(self):
    """
    Print the keyscale prefix tree to help debug generation bias.

    The tree built by ``_build_keyscale_prefix_tree`` is keyed by tuples of
    token IDs, so lookups here use ``()`` for the empty prefix and
    ``(token_id,)`` for each allowed first token. String keys such as ``""``
    or ``"A"`` can never match. Prints the allowed first tokens, the tokens
    allowed after each of them, and a sample of up to 10 valid keyscales.
    """
    print("=" * 60)
    print("KEYSCALE PREFIX TREE DIAGNOSIS")
    print("=" * 60)

    tree = self.keyscale_prefix_tree
    empty_prefix: Tuple[int, ...] = ()

    if empty_prefix in tree:
        first_tokens = tree[empty_prefix]
        print(f"\n[Empty prefix] Allowed first tokens ({len(first_tokens)} total):")
        for t in sorted(first_tokens):
            print(f" Token {t}: {repr(self.tokenizer.decode([t]))}")
    else:
        first_tokens = set()
        print("\nWARNING: Empty prefix not in tree!")

    # Second level: what may follow each first (note) token.
    for first in sorted(first_tokens):
        prefix = (first,)
        if prefix not in tree:
            continue
        tokens = tree[prefix]
        label = self.tokenizer.decode([first])
        print(f"\n[Prefix {repr(label)}] Allowed tokens ({len(tokens)}):")
        for t in sorted(tokens):
            print(f" Token {t}: {repr(self.tokenizer.decode([t]))}")

    print(f"\n[Valid keyscales] Total: {len(self.valid_keyscales)}")
    for ks in sorted(self.valid_keyscales)[:10]:
        print(f" {repr(ks)}")

    print("=" * 60)
|
|
|
|
|
def _load_genres_vocab(self):
    """
    (Re)load the genres vocabulary from ``self.genres_vocab_path``.

    The file is read only when its mtime is newer than the last load, which
    allows hot reloading. Format: one genre per line; blank lines and lines
    starting with '#' are ignored; entries are stripped and lower-cased. On a
    successful load the genres trie is rebuilt. A missing file is a silent
    no-op (debug log only); any other failure is logged as a warning.
    """
    path = self.genres_vocab_path
    if not os.path.exists(path):
        if self.debug:
            logger.debug(f"Genres vocab file not found: {self.genres_vocab_path}")
        return

    try:
        mtime = os.path.getmtime(path)
        if mtime <= self.genres_vocab_mtime:
            return  # unchanged since the last load

        with open(path, 'r', encoding='utf-8') as handle:
            cleaned = (raw.strip() for raw in handle)
            loaded = [entry.lower() for entry in cleaned if entry and not entry.startswith('#')]

        self.genres_vocab = loaded
        self.genres_vocab_mtime = mtime
        self._build_genres_trie()

        if self.debug:
            logger.debug(f"Loaded {len(self.genres_vocab)} genres from {self.genres_vocab_path}")
    except Exception as e:
        logger.warning(f"Failed to load genres vocab: {e}")
|
|
|
def _build_genres_trie(self):
    """
    Rebuild ``self.genres_trie``, a character trie over ``self.genres_vocab``.

    Each node is a dict mapping the next character to a child node; a node
    that completes a genre additionally has ``'_end': True``.
    """
    trie = self.genres_trie = {}

    for genre in self.genres_vocab:
        cursor = trie
        for ch in genre:
            cursor = cursor.setdefault(ch, {})
        cursor['_end'] = True

    if self.debug:
        logger.debug(f"Built genres trie with {len(self.genres_vocab)} entries")
|
|
|
def _extract_caption_genres(self, caption: str):
    """
    Find vocabulary genres mentioned in ``caption`` and build a caption-specific trie.

    The smaller trie makes genre generation faster and more relevant.

    Strategy (O(words * max_genre_len) rather than O(vocab_size)):
        1. Split the lower-cased caption on commas, whitespace, '-', '_', '/',
           backslash and '|'; keep pieces of length >= 2.
        2. For each word, fetch its trie node via ``_get_genres_trie_node`` and
           collect every complete genre below it (entries starting with the word).
        3. Add words that are themselves exact vocabulary entries.
        4. Build ``self.caption_genres_trie`` from the matches.

    ``caption_matched_genres`` and ``caption_genres_trie`` are reset first. A
    caption with no matches (or an empty caption / empty vocabulary) therefore
    leaves them empty, so the full vocabulary is used, instead of keeping the
    previous caption's matches.
    """
    import re

    # Reset per-caption state up front so early returns never leave stale matches.
    self.caption_matched_genres = []
    self.caption_genres_trie = {}

    if not caption or not self.genres_vocab:
        return

    matched_genres: Set[str] = set()
    raw_words = re.split(r'[,\s\-_/\\|]+', caption.lower())
    words = [w for w in (piece.strip() for piece in raw_words) if len(w) >= 2]

    # Vocabulary entries that start with a caption word.
    for word in words:
        node = self._get_genres_trie_node(word)
        if node is not None:
            self._collect_complete_genres(node, word, matched_genres)

    # Caption words that are exact vocabulary entries.
    genres_set = set(self.genres_vocab)
    matched_genres.update(word for word in words if word in genres_set)

    if not matched_genres:
        if self.debug:
            logger.debug("No genres matched in caption, using full vocab")
        return

    self.caption_matched_genres = list(matched_genres)
    for genre in matched_genres:
        node = self.caption_genres_trie
        for char in genre:
            node = node.setdefault(char, {})
        node['_end'] = True

    if self.debug:
        logger.debug(f"Matched {len(matched_genres)} genres from caption: {list(matched_genres)[:5]}...")
|
|
|
| def _collect_complete_genres(self, node: Dict, prefix: str, result: set, max_depth: int = 50):
|
| """
|
| Recursively collect all complete genres under a trie node.
|
| Limited depth to avoid too many matches.
|
| """
|
| if max_depth <= 0:
|
| return
|
|
|
| if node.get('_end', False):
|
| result.add(prefix)
|
|
|
|
|
| if len(result) >= 100:
|
| return
|
|
|
| for char, child_node in node.items():
|
| if char not in ('_end', '_tokens'):
|
| self._collect_complete_genres(child_node, prefix + char, result, max_depth - 1)
|
|
|
| def _precompute_char_token_mapping(self):
|
| """
|
| Precompute mapping from characters to token IDs and token decoded texts.
|
| This allows O(1) lookup instead of calling tokenizer.encode()/decode() at runtime.
|
|
|
| Time complexity: O(vocab_size) - runs once during initialization
|
|
|
| Note: Many subword tokenizers (like Qwen) add space prefixes to tokens.
|
| We need to handle both the raw first char and the first non-space char.
|
| """
|
| self._char_to_tokens: Dict[str, set] = {}
|
| self._token_to_text: Dict[int, str] = {}
|
|
|
|
|
| for token_id, text in enumerate(self._get_decoded_vocab_texts()):
|
| if not text:
|
| continue
|
|
|
|
|
|
|
|
|
| text_lower = text.lower()
|
| if text_lower.strip():
|
| normalized_text = text_lower.rstrip()
|
| else:
|
| normalized_text = " "
|
| self._token_to_text[token_id] = normalized_text
|
|
|
|
|
| first_char = text[0].lower()
|
| if first_char not in self._char_to_tokens:
|
| self._char_to_tokens[first_char] = set()
|
| self._char_to_tokens[first_char].add(token_id)
|
|
|
|
|
|
|
| stripped_text = text.lstrip()
|
| if stripped_text and stripped_text != text:
|
| first_nonspace_char = stripped_text[0].lower()
|
| if first_nonspace_char not in self._char_to_tokens:
|
| self._char_to_tokens[first_nonspace_char] = set()
|
| self._char_to_tokens[first_nonspace_char].add(token_id)
|
|
|
|
|
| self._decoded_vocab_texts = None
|
|
|
| if self.debug:
|
| logger.debug(f"Precomputed char->token mapping for {len(self._char_to_tokens)} unique characters")
|
|
|
| def _try_reload_genres_vocab(self):
|
| """Check if genres vocab file has been updated and reload if necessary."""
|
| if not os.path.exists(self.genres_vocab_path):
|
| return
|
|
|
| try:
|
| mtime = os.path.getmtime(self.genres_vocab_path)
|
| if mtime > self.genres_vocab_mtime:
|
| self._load_genres_vocab()
|
| except Exception:
|
| pass
|
|
|
| def _get_genres_trie_node(self, prefix: str) -> Optional[Dict]:
|
| """
|
| Get the trie node for a given prefix.
|
| Returns None if the prefix is not valid (no genres start with this prefix).
|
| """
|
| node = self.genres_trie
|
| for char in prefix.lower():
|
| if char not in node:
|
| return None
|
| node = node[char]
|
| return node
|
|
|
| def _is_complete_genre(self, text: str) -> bool:
|
| """Check if the given text is a complete genre in the vocabulary."""
|
| node = self._get_genres_trie_node(text.strip())
|
| return node is not None and node.get('_end', False)
|
|
|
| def _get_trie_node_from_trie(self, trie: Dict, prefix: str) -> Optional[Dict]:
|
| """Get a trie node from a specific trie (helper for caption vs full trie)."""
|
| node = trie
|
| for char in prefix.lower():
|
| if char not in node:
|
| return None
|
| node = node[char]
|
| return node
|
|
|
| def _get_allowed_genres_tokens(self) -> List[int]:
|
| """
|
| Get allowed tokens for genres field based on trie matching.
|
|
|
| The entire genres string (including commas) must match a complete entry in the vocab.
|
| For example, if vocab contains "pop, rock, jazz", the generated string must exactly
|
| match that entry - we don't treat commas as separators for individual genres.
|
|
|
| Strategy:
|
| 1. If caption-matched genres exist, use that smaller trie first (faster + more relevant)
|
| 2. If no caption matches or prefix not in caption trie, fallback to full vocab trie
|
| 3. Get valid next characters from current trie node
|
| 4. For each candidate token, verify the full decoded text forms a valid trie prefix
|
| """
|
| if not self.genres_vocab:
|
|
|
| return []
|
|
|
|
|
| accumulated = self.accumulated_value.lower()
|
| current_genre_prefix = accumulated.strip()
|
|
|
|
|
| use_caption_trie = False
|
| current_node = None
|
|
|
|
|
| if self.caption_genres_trie:
|
| if current_genre_prefix == "":
|
| current_node = self.caption_genres_trie
|
| use_caption_trie = True
|
| else:
|
| current_node = self._get_trie_node_from_trie(self.caption_genres_trie, current_genre_prefix)
|
| if current_node is not None:
|
| use_caption_trie = True
|
|
|
|
|
| if current_node is None:
|
| if current_genre_prefix == "":
|
| current_node = self.genres_trie
|
| else:
|
| current_node = self._get_genres_trie_node(current_genre_prefix)
|
|
|
| if current_node is None:
|
|
|
| if self.newline_token:
|
| return [self.newline_token]
|
| return []
|
|
|
|
|
| valid_next_chars = set(k for k in current_node.keys() if k not in ('_end', '_tokens'))
|
|
|
|
|
| is_complete = current_node.get('_end', False)
|
|
|
| if not valid_next_chars:
|
|
|
| allowed = set()
|
| if is_complete and self.newline_token:
|
| allowed.add(self.newline_token)
|
| return list(allowed)
|
|
|
|
|
| candidate_tokens = set()
|
| for char in valid_next_chars:
|
| if char in self._char_to_tokens:
|
| candidate_tokens.update(self._char_to_tokens[char])
|
|
|
|
|
| active_trie = self.caption_genres_trie if use_caption_trie else self.genres_trie
|
|
|
|
|
| allowed = set()
|
| for token_id in candidate_tokens:
|
|
|
| decoded_normalized = self._token_to_text.get(token_id, "")
|
|
|
| if not decoded_normalized or not decoded_normalized.strip():
|
|
|
| if ' ' in valid_next_chars or ',' in valid_next_chars:
|
| allowed.add(token_id)
|
| continue
|
|
|
|
|
|
|
| if decoded_normalized.startswith(' ') or decoded_normalized.startswith(','):
|
|
|
| new_prefix = current_genre_prefix + decoded_normalized
|
| else:
|
| new_prefix = current_genre_prefix + decoded_normalized
|
|
|
|
|
| new_node = self._get_trie_node_from_trie(active_trie, new_prefix)
|
| if new_node is not None:
|
| allowed.add(token_id)
|
|
|
|
|
| if is_complete and self.newline_token:
|
| allowed.add(self.newline_token)
|
|
|
| return list(allowed)
|
|
|
def reset(self):
    """
    Put the processor back into its initial FSM state before a new generation.

    Sets the state to THINK_TAG and clears the in-state position, the
    accumulated field value and token ids, the code counter, any queued
    user-provided field tokens, and the caption bookkeeping flags.
    """
    self.state = FSMState.THINK_TAG

    # Fresh containers are created on every call, so nothing is shared between runs.
    initial_values = {
        "position_in_state": 0,
        "accumulated_value": "",
        "accumulated_token_ids": [],
        "codes_count": 0,
        "user_field_token_queue": [],
        "current_user_field": None,
        "caption_after_newline": False,
        "caption_token_count": 0,
        "caption_ending": False,
        "pending_field_name": "",
    }
    for attr_name, initial in initial_values.items():
        setattr(self, attr_name, initial)
|
|
|
def set_target_duration(self, duration: Optional[float]):
    """
    Set the target duration for codes generation.

    Args:
        duration: Target duration in seconds. At 5 codes per second the code
            budget is ``int(duration * 5)``. ``None`` or a non-positive value
            clears the constraint; both ``target_duration`` and ``target_codes``
            then become None. A non-positive duration is not kept, because the
            duration-value state would otherwise force its digits (0, or a crash
            on the '-' sign of a negative value).
    """
    if duration is not None and duration > 0:
        self.target_duration = duration
        self.target_codes = int(duration * 5)
        if self.debug:
            logger.debug(f"Set target duration: {duration}s -> {self.target_codes} codes")
    else:
        self.target_duration = None
        self.target_codes = None
        if self.debug:
            logger.debug("Target duration cleared, no duration constraint")
|
|
|
def set_max_duration(self, max_duration: int):
    """
    Change the maximum duration accepted by constrained decoding.

    Call this when the GPU/LM configuration changes. It updates
    ``field_specs["duration"]["max"]``, regenerates the list of valid duration
    strings, and rebuilds the duration prefix tree. Calling it with the current
    maximum is a no-op.

    Args:
        max_duration: New maximum duration in seconds (e.g. 120 for 2 minutes, 360 for 6 minutes).
    """
    if max_duration == self.max_duration:
        return

    previous_max = self.max_duration
    self.max_duration = max_duration

    duration_spec = self.field_specs["duration"]
    duration_spec["max"] = max_duration

    lower, upper = duration_spec["min"], duration_spec["max"]
    self.valid_duration_values = list(map(str, range(lower, upper + 1)))

    self.duration_prefix_tree = self._build_numeric_prefix_tree(
        self.valid_duration_values,
        context_prefix_for_matching="duration:",
        context_prefix_for_tokenization="duration: ",
    )

    if self.debug:
        logger.debug(f"Updated max duration: {previous_max}s -> {max_duration}s, rebuilt prefix tree with {len(self.valid_duration_values)} values")
|
|
|
| def _get_allowed_tokens_for_fixed_string(self, fixed_str: str) -> List[int]:
|
| """
|
| Get the token IDs that can continue the fixed string from current position.
|
| Returns list of allowed token IDs.
|
|
|
| Strategy: Find the longest prefix that encodes to a single token, and return that token.
|
| This ensures we generate by tokens, not character-by-character.
|
| """
|
| remaining = fixed_str[self.position_in_state:]
|
| if not remaining:
|
| return []
|
|
|
| if self.debug:
|
| logger.debug(f"_get_allowed_tokens_for_fixed_string: fixed_str={repr(fixed_str)}, position_in_state={self.position_in_state}, remaining={repr(remaining)}")
|
|
|
|
|
|
|
| best_token = None
|
| best_prefix_len = 0
|
|
|
|
|
| for end in range(len(remaining), 0, -1):
|
| prefix = remaining[:end]
|
| tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
|
| if tokens and len(tokens) == 1:
|
|
|
|
|
| best_token = tokens[0]
|
| best_prefix_len = end
|
| if self.debug:
|
| logger.debug(f"Found single-token match: prefix={repr(prefix)}, token_id={best_token}, token_text={repr(self.tokenizer.decode([best_token]))}")
|
| break
|
|
|
|
|
| if best_token is not None:
|
| return [best_token]
|
|
|
|
|
|
|
|
|
|
|
|
|
| allowed_tokens = {}
|
| for end in range(1, min(len(remaining) + 1, 20)):
|
| prefix = remaining[:end]
|
| tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
|
| if tokens:
|
| first_token = tokens[0]
|
|
|
| decoded_token = self.tokenizer.decode([first_token])
|
|
|
| normalized_prefix = prefix.lstrip().lower()
|
| normalized_decoded = decoded_token.lstrip().lower()
|
|
|
|
|
| if normalized_decoded.startswith(normalized_prefix) or normalized_prefix.startswith(normalized_decoded):
|
|
|
| if first_token not in allowed_tokens or end > allowed_tokens[first_token]:
|
| allowed_tokens[first_token] = end
|
|
|
|
|
|
|
| sorted_tokens = sorted(allowed_tokens.items(), key=lambda x: x[1], reverse=True)
|
| result = [token for token, _ in sorted_tokens] if sorted_tokens else []
|
|
|
| if self.debug:
|
| logger.debug(f"Fallback: returning {len(result)} tokens: {[(t, repr(self.tokenizer.decode([t]))) for t in result[:5]]}")
|
| if result:
|
| logger.debug(f"Fixed string: {repr(fixed_str)}, position: {self.position_in_state}, remaining: {repr(remaining)}")
|
|
|
| return result
|
|
|
| def _get_allowed_digit_tokens(self, min_val: int, max_val: int) -> List[int]:
|
| """
|
| Get allowed digit tokens based on accumulated value and range constraints.
|
| Uses early-blocking to prevent out-of-range values.
|
| """
|
| if not self.accumulated_value:
|
|
|
| allowed_digits = set()
|
| for v in range(min_val, max_val + 1):
|
| allowed_digits.add(int(str(v)[0]))
|
| return [self.digit_tokens[d] for d in allowed_digits if d in self.digit_tokens]
|
|
|
| current = int(self.accumulated_value)
|
| allowed = []
|
|
|
| for d in range(10):
|
| new_value = int(self.accumulated_value + str(d))
|
|
|
|
|
|
|
|
|
|
|
|
|
| if new_value > max_val:
|
| continue
|
|
|
|
|
|
|
|
|
| if new_value >= min_val:
|
| allowed.append(d)
|
| elif new_value * 10 <= max_val:
|
|
|
| allowed.append(d)
|
|
|
| return [self.digit_tokens[d] for d in allowed if d in self.digit_tokens]
|
|
|
| def _get_allowed_numeric_tokens(self, prefix_tree: Dict[Tuple[int, ...], Set[int]]) -> List[int]:
|
| """
|
| Get allowed tokens for numeric field using the precomputed prefix tree.
|
|
|
| IMPORTANT: Uses token ID sequence as key (not string) to avoid tokenization mismatches.
|
|
|
| Args:
|
| prefix_tree: Precomputed prefix tree mapping token ID sequence -> set of allowed token IDs
|
|
|
| Returns:
|
| List of allowed token IDs for current accumulated_token_ids
|
| """
|
| token_prefix = tuple(self.accumulated_token_ids)
|
|
|
| if token_prefix in prefix_tree:
|
| return list(prefix_tree[token_prefix])
|
|
|
|
|
|
|
| return []
|
|
|
| def _should_end_numeric_field(self, logits: torch.Tensor, min_val: int, max_val: int) -> bool:
|
| """
|
| Determine if we should end the current numeric field.
|
| Returns True if P(newline) > P(any valid digit) AND current value is valid.
|
| """
|
| if not self.accumulated_value:
|
| return False
|
|
|
| current = int(self.accumulated_value)
|
| if current < min_val or current > max_val:
|
| return False
|
|
|
|
|
| probs = torch.softmax(logits, dim=-1)
|
|
|
| newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0
|
|
|
|
|
| allowed_digits = self._get_allowed_digit_tokens(min_val, max_val)
|
| if not allowed_digits:
|
| return True
|
|
|
| max_digit_prob = max(probs[0, t].item() for t in allowed_digits)
|
|
|
| if self.debug:
|
| logger.debug(f"Numeric field decision: newline_prob={newline_prob:.4f}, max_digit_prob={max_digit_prob:.4f}")
|
|
|
| return newline_prob > max_digit_prob
|
|
|
|
|
| def _should_end_text_field(self, logits: torch.Tensor) -> bool:
|
| """
|
| Determine if we should end a text field (genres).
|
| Returns True if P(newline) > P(any other token) AND we have some content.
|
| """
|
| if not self.accumulated_value.strip():
|
| return False
|
|
|
| probs = torch.softmax(logits, dim=-1)
|
| newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0
|
|
|
|
|
| masked_probs = probs.clone()
|
| if self.newline_token:
|
| masked_probs[0, self.newline_token] = 0
|
| max_other_prob = masked_probs[0].max().item()
|
|
|
| return newline_prob > max_other_prob
|
|
|
| def _get_allowed_keyscale_tokens(self) -> List[int]:
|
| """
|
| Get allowed tokens for keyscale field using the precomputed prefix tree.
|
| Uses token ID sequence as key (not string) to avoid tokenization mismatches.
|
| """
|
|
|
| token_prefix = tuple(self.accumulated_token_ids)
|
|
|
| if token_prefix in self.keyscale_prefix_tree:
|
| return list(self.keyscale_prefix_tree[token_prefix])
|
|
|
|
|
|
|
| return []
|
|
|
| def _is_keyscale_complete(self) -> bool:
|
| """
|
| Check if keyscale value is complete and valid.
|
| Uses token ID sequence to check if current prefix allows newline.
|
| """
|
| token_prefix = tuple(self.accumulated_token_ids)
|
|
|
| if token_prefix in self.keyscale_prefix_tree:
|
| return self.newline_token in self.keyscale_prefix_tree[token_prefix]
|
| return False
|
|
|
| def _get_allowed_language_tokens(self) -> List[int]:
|
| """
|
| Get allowed tokens for language field using the precomputed prefix tree.
|
| Uses token ID sequence as key (not string) to avoid tokenization mismatches.
|
| Similar to keyscale.
|
| """
|
| token_prefix = tuple(self.accumulated_token_ids)
|
|
|
| if token_prefix in self.language_prefix_tree:
|
| return list(self.language_prefix_tree[token_prefix])
|
|
|
|
|
| return []
|
|
|
| def _get_allowed_timesig_tokens(self) -> List[int]:
|
| """
|
| Get allowed tokens for timesignature field using the precomputed prefix tree.
|
| Uses token ID sequence as key (not string) to avoid tokenization mismatches.
|
| """
|
| token_prefix = tuple(self.accumulated_token_ids)
|
|
|
| if token_prefix in self.timesig_prefix_tree:
|
| return list(self.timesig_prefix_tree[token_prefix])
|
|
|
|
|
|
|
| return []
|
|
|
def __call__(
    self,
    input_ids: torch.LongTensor,
    scores: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Apply constrained decoding by modifying logits.

    Dispatch order:
      1. Processor disabled -> temperature scaling only.
      2. FSM COMPLETED -> in the "understand" phase, add ``audio_code_mask``
         (moved to the logits' device/dtype and cached); then temperature scaling.
      3. "codes" phase still at THINK_TAG -> if ``</think>`` already appears in
         ``input_ids``, jump straight to CODES_GENERATION.
      4. CODES_GENERATION -> add ``non_audio_code_mask``. When a target code
         count and an EOS id are set, block EOS until ``target_codes`` codes
         have been produced, then force EOS.
      5. Otherwise each batch row is constrained by ``_process_single_sequence``.

    Args:
        input_ids: [batch_size, seq_len] input token IDs
        scores: [batch_size, vocab_size] logits for next token

    Returns:
        Modified scores with invalid tokens masked to -inf and temperature
        scaling applied. Note that ``scores`` may also be written in place
        (EOS handling, per-row copy-back).
    """
    if not self.enabled:
        return self._apply_temperature_scaling(scores)

    if self.state == FSMState.COMPLETED:
        # After </think> in the "understand" phase, text continues freely,
        # but audio-code tokens stay masked out.
        if self.generation_phase == "understand" and self.audio_code_mask is not None:
            # Cache the mask on the logits' device/dtype to avoid converting every step.
            if self.audio_code_mask.device != scores.device or self.audio_code_mask.dtype != scores.dtype:
                self.audio_code_mask = self.audio_code_mask.to(device=scores.device, dtype=scores.dtype)
            scores = scores + self.audio_code_mask
        return self._apply_temperature_scaling(scores)

    # In the "codes" phase the prompt may already contain the finished
    # metadata block; if so, skip the metadata FSM entirely.
    if self.generation_phase == "codes" and self.state == FSMState.THINK_TAG:
        if self._input_contains_think_end_tag(input_ids):
            self.state = FSMState.CODES_GENERATION
            self.codes_count = 0
            if self.debug:
                logger.debug("Codes phase: detected </think> in input, skipping to CODES_GENERATION")

    if self.state == FSMState.CODES_GENERATION:
        # Restrict sampling to audio-code tokens, via an additive mask presumably
        # 0 for codes and -inf elsewhere (mask built outside this view).
        if self.non_audio_code_mask is not None:
            if self.non_audio_code_mask.device != scores.device or self.non_audio_code_mask.dtype != scores.dtype:
                self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
            scores = scores + self.non_audio_code_mask

        # Enforce the requested duration through the code count.
        if self.target_codes is not None and self.eos_token_id is not None:
            if self.codes_count < self.target_codes:
                # Too early to stop: forbid EOS.
                scores[:, self.eos_token_id] = float('-inf')
                if self.debug:
                    logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
            else:
                # Budget reached: keep only the EOS logit.
                eos_scores = scores[:, self.eos_token_id].clone()
                scores.fill_(float('-inf'))
                scores[:, self.eos_token_id] = eos_scores
                if self.debug:
                    logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
        return self._apply_temperature_scaling(scores)

    batch_size = scores.shape[0]

    # NOTE(review): every row is processed against the single shared FSM state
    # (self.state), so batch_size > 1 presumably lets rows interfere; confirm.
    for b in range(batch_size):
        result = self._process_single_sequence(input_ids[b], scores[b:b+1])
        # _process_single_sequence may return a new tensor, so copy its row back.
        scores[b] = result[0]

    return self._apply_temperature_scaling(scores)
|
|
|
| def _input_contains_think_end_tag(self, input_ids: torch.LongTensor) -> bool:
|
| """
|
| Check if input contains the </think> closing tag.
|
|
|
| Args:
|
| input_ids: [batch_size, seq_len] input token IDs
|
|
|
| Returns:
|
| True if </think> is found in the input (any sequence in batch)
|
| """
|
|
|
| think_end_tokens = self.tokenizer.encode("</think>", add_special_tokens=False)
|
| if not think_end_tokens:
|
| return False
|
|
|
|
|
| for b in range(input_ids.shape[0]):
|
| seq = input_ids[b].tolist()
|
|
|
| for i in range(len(seq) - len(think_end_tokens) + 1):
|
| if seq[i:i+len(think_end_tokens)] == think_end_tokens:
|
| return True
|
|
|
| return False
|
|
|
def _apply_temperature_scaling(self, scores: torch.FloatTensor) -> torch.FloatTensor:
    """
    Divide the logits by the temperature of the current generation phase.

    ``codes_temperature`` applies during CODES_GENERATION and COMPLETED;
    ``metadata_temperature`` applies in every other state. A temperature
    below 1.0 sharpens the distribution and one above 1.0 flattens it.
    ``None`` leaves the logits untouched, and a non-positive temperature is
    replaced by 1e-6 (near-greedy).

    Args:
        scores: [batch_size, vocab_size] logits

    Returns:
        Temperature-scaled logits (a new tensor unless the temperature is None).
    """
    in_codes_phase = self.state in (FSMState.CODES_GENERATION, FSMState.COMPLETED)
    temperature = self.codes_temperature if in_codes_phase else self.metadata_temperature

    if temperature is None:
        return scores

    effective = 1e-6 if temperature <= 0 else temperature
    return scores / effective
|
|
|
| def _get_user_provided_field_tokens(self, field_name: str) -> Optional[List[int]]:
|
| """
|
| Get token sequence for a user-provided field (field_name + value + newline).
|
| Uses the same tokenization logic as prefix tree building.
|
|
|
| Args:
|
| field_name: Field name ("bpm", "caption", "duration", "keyscale", "language", "timesignature")
|
|
|
| Returns:
|
| List of token IDs for the complete field, or None if field is not provided
|
| """
|
| value = self.user_provided_metadata.get(field_name)
|
| if value is None:
|
| return None
|
|
|
|
|
| field_to_prefix = {
|
| "bpm": "bpm: ",
|
| "caption": "caption: ",
|
| "duration": "duration: ",
|
| "keyscale": "keyscale: ",
|
| "language": "language: ",
|
| "timesignature": "timesignature: ",
|
| "genres": "genres: ",
|
| }
|
| prefix = field_to_prefix[field_name]
|
| full_text = f"{prefix}{value}\n"
|
|
|
|
|
| tokens = self.tokenizer.encode(full_text, add_special_tokens=False)
|
|
|
|
|
|
|
| prefix_for_matching = field_name + ":"
|
| prefix_tokens = self.tokenizer.encode(prefix_for_matching, add_special_tokens=False)
|
|
|
|
|
| if len(tokens) >= len(prefix_tokens) and tokens[:len(prefix_tokens)] == prefix_tokens:
|
|
|
| return tokens[len(prefix_tokens):]
|
| else:
|
|
|
| if self.debug:
|
| logger.warning(f"Could not match prefix tokens for field {field_name}, using all tokens")
|
| return tokens
|
|
|
def _process_single_sequence(
    self,
    input_ids: torch.LongTensor,
    scores: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Constrain the next-token logits of one sequence according to the FSM state.

    *scores* is a [1, vocab_size] slice. It is modified in place where
    possible, but some branches (caption masking) return a new tensor, so
    callers must use the returned value.

    Behavior by state:
      * Queued user-provided field tokens are forced one at a time.
      * Fixed-string states force the next token of their literal. When
        ``stop_at_reasoning`` is set, EOS is forced near or after the end of
        ``</think>``. Once the literal is exhausted, the FSM transitions and
        the new state is processed.
      * BPM / duration values follow their numeric prefix trees; duration may
        instead be pinned to ``target_duration`` digit by digit.
      * The caption is free text: backticks and audio-code tokens are blocked,
        and a newline is forced after 512 caption tokens.
      * Genres follow the genres trie when a vocabulary is loaded, and free
        text otherwise.
      * Keyscale / language / timesignature follow their prefix trees and fall
        back to a newline when no continuation exists.
    Every value state first replays a user-provided value, if one exists.

    Args:
        input_ids: Token ids of this sequence (passed through to the recursive call).
        scores: [1, vocab_size] logits for this sequence.

    Returns:
        The constrained logits.
    """
    # Replay user-provided field tokens one at a time (the queue is presumably
    # consumed by update_state as tokens are generated).
    if self.user_field_token_queue:
        next_token = self.user_field_token_queue[0]
        self._apply_whitelist_inplace(scores, [next_token])
        return scores

    if self.state in self.fixed_strings:
        fixed_str = self.fixed_strings[self.state]
        allowed = self._get_allowed_tokens_for_fixed_string(fixed_str)

        if allowed:
            if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                remaining_chars = len(fixed_str) - self.position_in_state
                # Close to the end of "</think>": stop generation outright.
                if remaining_chars <= 10:
                    if self.eos_token_id is not None:
                        self._apply_whitelist_inplace(scores, [self.eos_token_id])
                        if self.debug:
                            logger.debug(f"stop_at_reasoning=True: forcing EOS near end of </think> tag (remaining: {remaining_chars} chars)")
                        return scores

            self._apply_whitelist_inplace(scores, allowed)
        else:
            # The fixed string has been fully emitted.
            if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                if self.eos_token_id is not None:
                    self._apply_whitelist_inplace(scores, [self.eos_token_id])
                    if self.debug:
                        logger.debug("stop_at_reasoning=True: forcing EOS after completing </think> tag")
                    return scores

            old_state = self.state
            self._transition_to_next_state()

            if self.state in self.fixed_strings:
                # Do not recurse through consecutive fixed-string states.
                if self.debug:
                    logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion")
                return scores

            # NOTE(review): zero_() discards the model's logits before the new
            # state's constraints are applied, leaving a flat distribution over
            # the allowed tokens. Kept as-is; confirm this is intended.
            scores.zero_()
            return self._process_single_sequence(input_ids, scores)

    elif self.state == FSMState.BPM_VALUE:
        if self.user_provided_metadata["bpm"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
            value = self.user_provided_metadata["bpm"]
            # Leading space matches how the value follows "bpm:" in context.
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "bpm"
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores

        allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree)
        token_prefix = tuple(self.accumulated_token_ids)
        if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]:
            allowed = allowed + [self.newline_token]
        self._apply_whitelist_inplace(scores, allowed)

    elif self.state == FSMState.CAPTION_VALUE:
        if self.user_provided_metadata["caption"] is not None and not self.user_field_token_queue and not self.accumulated_value:
            value = self.user_provided_metadata["caption"]
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "caption"
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores

        if self.caption_after_newline:
            # After a newline inside the caption: if the model's top choice does
            # not start with indentation, the caption is ending and a new field
            # name is about to begin.
            top_token_id = torch.argmax(scores[0]).item()
            top_token_text = self.tokenizer.decode([top_token_id])
            if len(top_token_text) > 0 and top_token_text[0] not in ' \t':
                self.caption_after_newline = False
                self.caption_ending = True
                self.pending_field_name = ""
                return scores
            else:
                self.caption_after_newline = False

        if self.caption_ending:
            return scores

        if self.backtick_token is not None:
            # Block code fences inside the caption.
            scores[0, self.backtick_token] = float('-inf')

        if self.audio_code_mask is not None:
            if self.audio_code_mask.device != scores.device or self.audio_code_mask.dtype != scores.dtype:
                self.audio_code_mask = self.audio_code_mask.to(device=scores.device, dtype=scores.dtype)
            # Not in place: from here on `scores` is a new tensor.
            scores = scores + self.audio_code_mask

        if self.caption_token_count >= 512:
            # Hard length cap on the caption.
            if self.newline_token is not None:
                self._apply_whitelist_inplace(scores, [self.newline_token])
            return scores

        return scores

    elif self.state == FSMState.DURATION_VALUE:
        if self.user_provided_metadata["duration"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
            value = self.user_provided_metadata["duration"]
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "duration"
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores

        # Only a positive target can be spelled as digits: str(int(-5)) would
        # make int('-') raise, and 0 is below the valid duration range.
        if self.target_duration is not None and self.target_duration > 0:
            target_str = str(int(self.target_duration))
            current_pos = len(self.accumulated_value)

            if current_pos < len(target_str):
                # Force the next digit of the target duration.
                next_digit = int(target_str[current_pos])
                if next_digit in self.digit_tokens:
                    self._apply_whitelist_inplace(scores, [self.digit_tokens[next_digit]])
            else:
                # All target digits emitted: end the field.
                if self.newline_token:
                    self._apply_whitelist_inplace(scores, [self.newline_token])
        else:
            allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree)
            token_prefix = tuple(self.accumulated_token_ids)
            if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]:
                allowed = allowed + [self.newline_token]
            self._apply_whitelist_inplace(scores, allowed)

    elif self.state == FSMState.GENRES_VALUE:
        if self.user_provided_metadata["genres"] is not None and not self.user_field_token_queue and not self.accumulated_value:
            value = self.user_provided_metadata["genres"]
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "genres"
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores

        self._try_reload_genres_vocab()

        allowed = self._get_allowed_genres_tokens()
        if allowed:
            self._apply_whitelist_inplace(scores, allowed)
        elif self.genres_vocab:
            # A vocabulary exists but nothing continues the prefix: end the field.
            if self.newline_token:
                if self.debug:
                    logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
                self._apply_whitelist_inplace(scores, [self.newline_token])
        else:
            # No vocabulary: free text, ended when newline dominates.
            if self._should_end_text_field(scores):
                if self.newline_token:
                    self._apply_whitelist_inplace(scores, [self.newline_token])
                # NOTE(review): this transitions before the newline is actually
                # generated, unlike the other fields; confirm against update_state.
                self._transition_to_next_state()
            else:
                # Do not allow an empty genres value.
                if not self.accumulated_value.strip():
                    if self.newline_token:
                        scores[0, self.newline_token] = float('-inf')

    elif self.state == FSMState.KEYSCALE_VALUE:
        if self.user_provided_metadata["keyscale"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
            value = self.user_provided_metadata["keyscale"]
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "keyscale"
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores

        token_prefix = tuple(self.accumulated_token_ids)
        if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]:
            # The value is complete: end the field.
            if self.newline_token:
                self._apply_whitelist_inplace(scores, [self.newline_token])
        else:
            allowed = self._get_allowed_keyscale_tokens()
            if allowed:
                self._apply_whitelist_inplace(scores, allowed)
            else:
                if self.newline_token:
                    self._apply_whitelist_inplace(scores, [self.newline_token])

    elif self.state == FSMState.LANGUAGE_VALUE:
        if self.user_provided_metadata["language"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
            value = self.user_provided_metadata["language"]
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "language"
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores

        if not self.accumulated_token_ids:
            # First language token: commit to the model's single best candidate.
            empty_prefix = tuple()
            if empty_prefix in self.language_prefix_tree:
                candidate_tokens = list(self.language_prefix_tree[empty_prefix])
                if candidate_tokens:
                    candidate_indices = torch.tensor(candidate_tokens, device=scores.device, dtype=torch.long)
                    candidate_scores = scores[0, candidate_indices]
                    best_idx = torch.argmax(candidate_scores).item()
                    top_token_id = candidate_tokens[best_idx]
                    self._apply_whitelist_inplace(scores, [top_token_id])
                    if self.debug:
                        top_token_text = self.tokenizer.decode([top_token_id])
                        logger.debug(f"Language field: selected top-1 token {top_token_id} ({repr(top_token_text)}) from {len(candidate_tokens)} candidates")
                else:
                    if self.newline_token:
                        self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                if self.newline_token:
                    self._apply_whitelist_inplace(scores, [self.newline_token])
        else:
            token_prefix = tuple(self.accumulated_token_ids)
            if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
                if self.newline_token:
                    self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                allowed = self._get_allowed_language_tokens()
                if allowed:
                    self._apply_whitelist_inplace(scores, allowed)
                else:
                    if self.newline_token:
                        self._apply_whitelist_inplace(scores, [self.newline_token])

    elif self.state == FSMState.TIMESIG_VALUE:
        if self.user_provided_metadata["timesignature"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
            value = self.user_provided_metadata["timesignature"]
            value_text = f" {value}\n"
            value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
            if value_tokens:
                self.user_field_token_queue = value_tokens
                self.current_user_field = "timesignature"
                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores

        token_prefix = tuple(self.accumulated_token_ids)
        if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]:
            if self.newline_token:
                self._apply_whitelist_inplace(scores, [self.newline_token])
        else:
            allowed = self._get_allowed_timesig_tokens()
            if allowed:
                self._apply_whitelist_inplace(scores, allowed)
            elif self.newline_token:
                # Same fallback as keyscale/language: end the field instead of
                # whitelisting an empty set.
                self._apply_whitelist_inplace(scores, [self.newline_token])

    return scores
|
|
|
| def _transition_to_next_state(self):
|
| """Transition to the next FSM state."""
|
| if self.state in self.next_state:
|
| old_state = self.state
|
| next_state = self.next_state[self.state]
|
|
|
|
|
|
|
|
|
|
|
| if old_state == FSMState.THINK_END_TAG:
|
| if self.generation_phase == "understand":
|
|
|
|
|
| next_state = FSMState.COMPLETED
|
| if self.debug:
|
| logger.debug(f"generation_phase='understand': allowing free-form lyrics after </think>")
|
|
|
|
|
| self.state = next_state
|
| self.position_in_state = 0
|
| self.accumulated_value = ""
|
| self.accumulated_token_ids = []
|
| self.caption_after_newline = False
|
| self.caption_token_count = 0
|
| self.caption_ending = False
|
| self.pending_field_name = ""
|
| if self.debug:
|
| logger.debug(f"FSM transition: {old_state.name} -> {self.state.name}")
|
|
|
def update_state(self, generated_token_id: int):
    """
    Update internal state after a token has been generated.

    This should be called after each token generation. Depending on the
    current FSM state it:

    - does nothing when the processor is disabled or already ``COMPLETED``;
    - only increments ``codes_count`` in ``CODES_GENERATION``;
    - consumes one token of a pending user-provided field injection
      (``user_field_token_queue``) and, once the queue is drained, jumps to
      the state returned by ``_get_next_field_state`` (or calls
      ``_transition_to_next_state`` when that returns a falsy value);
    - for fixed-string states, advances ``position_in_state`` by the decoded
      token length and transitions once the whole fixed string is covered;
    - for value states, accumulates the token until ``newline_token`` ends
      the field. While ``caption_ending`` is set, caption tokens are also
      collected into ``pending_field_name`` until a ``:`` reveals which field
      follows the caption.

    Every direct state jump performed here resets the same per-field
    bookkeeping (position, accumulated text/ids, caption flags, pending field
    name) that ``_transition_to_next_state`` resets.

    Args:
        generated_token_id: The token ID that was just generated
    """
    if not self.enabled:
        return

    if self.state == FSMState.COMPLETED:
        return

    if self.state == FSMState.CODES_GENERATION:
        # Hot path: one call per audio-code token, so keep it minimal.
        self.codes_count += 1
        if self.debug and self.target_codes is not None:
            logger.debug(f"Codes count: {self.codes_count}/{self.target_codes}")
        return

    def _jump_to(new_state, reason):
        # Direct jump that bypasses self.next_state. Resets all per-field
        # tracking so stale caption/value state never leaks into new_state.
        old_state = self.state
        self.state = new_state
        self.position_in_state = 0
        self.accumulated_value = ""
        self.accumulated_token_ids = []
        self.caption_after_newline = False
        self.caption_token_count = 0
        self.caption_ending = False
        self.pending_field_name = ""
        if self.debug:
            logger.debug(f"FSM transition ({reason}): {old_state.name} -> {self.state.name}")

    # A user-provided field is being force-fed token by token: consume one
    # queued token and skip the normal per-state handling entirely.
    if self.user_field_token_queue:
        expected_token = self.user_field_token_queue[0]
        if generated_token_id != expected_token and self.debug:
            logger.warning(f"Expected token {expected_token} but got {generated_token_id} for user-provided field {self.current_user_field}")

        # The queue advances even on a mismatch so injection can never stall.
        self.user_field_token_queue.pop(0)

        if not self.user_field_token_queue:
            if self.debug:
                logger.debug(f"Completed injection of user-provided field: {self.current_user_field}")
            field_name = self.current_user_field
            self.current_user_field = None

            # The injected text already carried the value (and its newline),
            # so jump straight past this field's VALUE state.
            next_state = self._get_next_field_state(field_name)
            if next_state:
                _jump_to(next_state, "after user field injection")
            else:
                self._transition_to_next_state()
        return

    token_str = self.tokenizer.decode([generated_token_id])

    if self.debug:
        logger.debug(f"Generated token: {repr(token_str)} (id={generated_token_id}), state={self.state.name}")

    if self.state in self.fixed_strings:
        # Fixed literal (tag / field name): track how many characters of it
        # have been emitted so far.
        fixed_str = self.fixed_strings[self.state]
        self.position_in_state += len(token_str)

        if self.position_in_state >= len(fixed_str):
            if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                _jump_to(FSMState.COMPLETED, "stop_at_reasoning")
            else:
                self._transition_to_next_state()

    elif self.state in (FSMState.BPM_VALUE, FSMState.DURATION_VALUE, FSMState.TIMESIG_VALUE):
        if generated_token_id == self.newline_token:
            self._transition_to_next_state()
            # The newline belongs to the finished field; do not process it
            # any further in the new state.
            if self.state in self.fixed_strings:
                return
        else:
            # Token ids (including space tokens) drive prefix-tree lookups.
            self.accumulated_token_ids.append(generated_token_id)
            # Text form of the value: keep only decimal digits. isdecimal()
            # rejects characters like '²' that isdigit() accepts but int()
            # cannot parse.
            digits = token_str.strip()
            if digits.isdecimal():
                self.accumulated_value += digits

    elif self.state == FSMState.GENRES_VALUE:
        if generated_token_id == self.newline_token:
            self._transition_to_next_state()
            if self.state in self.fixed_strings:
                return
        else:
            self.accumulated_value += token_str

    elif self.state == FSMState.CAPTION_VALUE:
        self.caption_token_count += 1
        self.accumulated_value += token_str

        # A token such as ".\n" also counts as ending a line.
        self.caption_after_newline = '\n' in token_str

        if self.caption_ending:
            # The caption is over; these tokens spell out the next field name.
            self.pending_field_name += token_str

            if ':' in token_str:
                field_name = self.pending_field_name.strip().rstrip(':').strip().lower()

                if self.debug:
                    logger.debug(f"Detected field name after caption: {repr(field_name)}")

                field_name_to_value_state = {
                    "duration": FSMState.DURATION_VALUE,
                    "genres": FSMState.GENRES_VALUE,
                    "keyscale": FSMState.KEYSCALE_VALUE,
                    "language": FSMState.LANGUAGE_VALUE,
                    "timesignature": FSMState.TIMESIG_VALUE,
                }

                if field_name in field_name_to_value_state:
                    _jump_to(field_name_to_value_state[field_name], "caption ending")
                else:
                    if self.debug:
                        logger.warning(f"Unknown field name after caption: {repr(field_name)}, forcing transition")
                    self.caption_ending = False
                    self.pending_field_name = ""
                    self._transition_to_next_state()

    elif self.state in (FSMState.KEYSCALE_VALUE, FSMState.LANGUAGE_VALUE):
        if generated_token_id == self.newline_token:
            self._transition_to_next_state()
            if self.state in self.fixed_strings:
                return
        else:
            # Token ids drive the prefix-tree lookups for these fields; the
            # text form is kept alongside for compatibility.
            self.accumulated_token_ids.append(generated_token_id)
            self.accumulated_value += token_str
|
|
|
|
|
|
|