|
|
|
|
|
""" |
|
|
Constrained Logits Processor for ACE-Step Language Model |
|
|
|
|
|
This module implements a finite state machine (FSM) based logits processor that constrains |
|
|
the language model's output to follow specific formats and value ranges during music generation. |
|
|
|
|
|
Key Features: |
|
|
- Enforces structured metadata generation (BPM, duration, keyscale, etc.) |
|
|
- Validates numeric ranges (BPM: 30-300, Duration: 10-600s) |
|
|
- Ensures proper formatting for musical metadata |
|
|
- Prevents generation of invalid tokens or formats |
|
|
- Supports constrained audio code generation (0-63999) |
|
|
|
|
|
The FSM guides the model through different states to ensure outputs conform to expected |
|
|
schema requirements without post-processing corrections. |
|
|
|
|
|
Usage: |
|
|
    processor = MetadataConstrainedLogitsProcessor(tokenizer)
|
|
outputs = model.generate(inputs, logits_processor=[processor]) |
|
|
""" |
|
|
|
|
|
from enum import Enum, auto |
|
|
from typing import Optional, Dict, Any, Tuple, List, Callable, Set |
|
|
from loguru import logger |
|
|
from transformers import AutoTokenizer |
|
|
from transformers.generation.logits_process import LogitsProcessor |
|
|
import os |
|
|
import torch |
|
|
from acestep.constants import ( |
|
|
VALID_LANGUAGES, |
|
|
KEYSCALE_NOTES, |
|
|
KEYSCALE_ACCIDENTALS, |
|
|
KEYSCALE_MODES, |
|
|
VALID_KEYSCALES, |
|
|
BPM_MIN, |
|
|
BPM_MAX, |
|
|
DURATION_MIN, |
|
|
DURATION_MAX, |
|
|
VALID_TIME_SIGNATURES, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Highest code index accepted in <|audio_code_N|> tokens (64000 codes: 0..63999).
MAX_AUDIO_CODE = 63999
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FSMState(Enum):
    """Finite State Machine states for metadata generation.

    The processor walks these states while emitting the <think> metadata
    block: each field has a *_NAME state (emit the literal "field:" label)
    and a *_VALUE state (emit a constrained value), usually followed by a
    NEWLINE_AFTER_* state.
    """
    # Opening "<think>" tag and its trailing newline.
    THINK_TAG = auto()
    NEWLINE_AFTER_THINK = auto()
    # bpm: label, numeric value, newline.
    BPM_NAME = auto()
    BPM_VALUE = auto()
    NEWLINE_AFTER_BPM = auto()
    # caption: free text; note there is no NEWLINE_AFTER_CAPTION state —
    # the caption/value transition is handled specially.
    CAPTION_NAME = auto()
    CAPTION_VALUE = auto()
    # duration field.
    DURATION_NAME = auto()
    DURATION_VALUE = auto()
    NEWLINE_AFTER_DURATION = auto()
    # genres field (may be skipped via skip_genres).
    GENRES_NAME = auto()
    GENRES_VALUE = auto()
    NEWLINE_AFTER_GENRES = auto()
    # keyscale field.
    KEYSCALE_NAME = auto()
    KEYSCALE_VALUE = auto()
    NEWLINE_AFTER_KEYSCALE = auto()
    # language field (may be skipped via skip_language; no NEWLINE state here).
    LANGUAGE_NAME = auto()
    LANGUAGE_VALUE = auto()
    # timesignature field.
    TIMESIG_NAME = auto()
    TIMESIG_VALUE = auto()
    NEWLINE_AFTER_TIMESIG = auto()
    # Closing "</think>" tag, audio-code emission, terminal state.
    THINK_END_TAG = auto()
    CODES_GENERATION = auto()
    COMPLETED = auto()
|
|
|
|
|
|
|
|
class MetadataConstrainedLogitsProcessor(LogitsProcessor): |
|
|
""" |
|
|
FSM-driven LogitsProcessor that constrains generation to produce valid metadata. |
|
|
|
|
|
This processor enforces the following format: |
|
|
<think> |
|
|
bpm: [30-300] |
|
|
caption: [text without code blocks, ends with period + newline] |
|
|
duration: [10-600] |
|
|
keyscale: [A-G][#/♭]? [major/minor] |
|
|
language: [en/zh/ja/ko/es/fr/de/uk/ru/...] |
|
|
timesignature: [2/3/4/6] |
|
|
</think> |
|
|
|
|
|
It uses token masking (setting invalid token logits to -inf) to enforce constraints. |
|
|
For numeric fields, it uses early-blocking to prevent out-of-range values. |
|
|
For field transitions (e.g., end of numeric value), it compares P(newline) vs P(digit). |
|
|
For caption field, it blocks code blocks and newlines, and only transitions when |
|
|
the previous token was a period and newline has the highest probability. |
|
|
""" |
|
|
|
|
|
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        enabled: bool = True,
        debug: bool = False,
        genres_vocab_path: Optional[str] = None,
        skip_genres: bool = True,
        max_duration: Optional[int] = None,
    ):
        """
        Initialize the constrained logits processor.

        This processor should be initialized once when loading the LLM and reused
        for all generations (per-generation configuration is applied via the
        ``set_*`` methods).

        Args:
            tokenizer: The tokenizer to use for encoding/decoding
            enabled: Whether to enable constrained decoding
            debug: Whether to print debug information
            genres_vocab_path: Path to genres vocabulary file (defaults to
                "genres_vocab.txt" next to this module)
            skip_genres: Whether to skip genres field generation
            max_duration: Maximum duration in seconds (default: DURATION_MAX from constants)
        """
        self.tokenizer = tokenizer
        self.enabled = enabled
        self.debug = debug
        self.skip_genres = skip_genres

        # Duration upper bound is configurable; the lower bound comes from constants.
        self.max_duration = max_duration if max_duration is not None else DURATION_MAX
        self.skip_caption = False
        self.skip_language = False
        self.caption: Optional[str] = None

        # Fields the user supplied are emitted verbatim instead of generated
        # (see set_user_metadata).
        self.user_provided_metadata: Dict[str, Optional[str]] = {
            "bpm": None,
            "caption": None,
            "duration": None,
            "keyscale": None,
            "language": None,
            "timesignature": None,
            "genres": None,
        }

        # Optional per-phase sampling temperatures (None = caller's default).
        self.metadata_temperature: Optional[float] = None
        self.codes_temperature: Optional[float] = None

        # Targets and progress counter for the codes-generation phase.
        self.target_duration: Optional[float] = None
        self.target_codes: Optional[int] = None
        self.codes_count: int = 0

        # When True, generation stops right after </think> (no codes phase).
        self.stop_at_reasoning: bool = False

        # "cot", "codes", or "understand" — see set_generation_phase.
        self.generation_phase: str = "cot"

        # FSM cursor: current state, offset inside the state's fixed string,
        # and the value/token accumulators for the field being generated.
        self.state = FSMState.THINK_TAG
        self.position_in_state = 0
        self.accumulated_value = ""
        self.accumulated_token_ids: List[int] = []

        # Caption-specific tracking (captions are free text with special
        # termination rules; see the class docstring).
        self.caption_after_newline = False
        self.caption_token_count = 0
        self.caption_ending = False
        self.pending_field_name = ""

        # Queue of tokens to force-emit for a user-provided field value.
        self.user_field_token_queue: List[int] = []
        self.current_user_field: Optional[str] = None

        # Cache frequently used token IDs (digits, newline, notes, ...).
        self._precompute_tokens()

        # Genres vocabulary (hot-reloadable file; see _load_genres_vocab).
        self.genres_vocab_path = genres_vocab_path or os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "genres_vocab.txt"
        )
        self.genres_vocab: List[str] = []
        self.genres_vocab_mtime: float = 0.0
        self.genres_trie: Dict = {}
        self.caption_genres_trie: Dict = {}
        self.caption_matched_genres: List[str] = []

        # char -> token IDs mapping; populated by _precompute_char_token_mapping
        # (defined elsewhere in this module).
        self._char_to_tokens: Dict[str, set] = {}
        self._precompute_char_token_mapping()

        # Numeric/categorical specs for the range-constrained fields.
        self.field_specs = {
            "bpm": {"min": BPM_MIN, "max": BPM_MAX},
            "duration": {"min": DURATION_MIN, "max": self.max_duration},
            "timesignature": {"valid_values": VALID_TIME_SIGNATURES},
        }

        # Every valid value enumerated as a string (inputs to the prefix trees).
        self.valid_bpm_values = [str(v) for v in range(self.field_specs["bpm"]["min"], self.field_specs["bpm"]["max"] + 1)]
        self.valid_duration_values = [str(v) for v in range(self.field_specs["duration"]["min"], self.field_specs["duration"]["max"] + 1)]
        self.valid_timesig_values = [str(v) for v in self.field_specs["timesignature"]["valid_values"]]

        # Token-level prefix trees: map already-emitted value tokens to the
        # set of token IDs allowed next (built from actual tokenizations,
        # including the field-name context so boundary merging is handled).
        self.keyscale_prefix_tree = self._build_keyscale_prefix_tree()

        self.bpm_prefix_tree = self._build_numeric_prefix_tree(
            self.valid_bpm_values,
            context_prefix_for_matching="bpm:",
            context_prefix_for_tokenization="bpm: "
        )
        self.duration_prefix_tree = self._build_numeric_prefix_tree(
            self.valid_duration_values,
            context_prefix_for_matching="duration:",
            context_prefix_for_tokenization="duration: "
        )
        self.timesig_prefix_tree = self._build_numeric_prefix_tree(
            self.valid_timesig_values,
            context_prefix_for_matching="timesignature:",
            context_prefix_for_tokenization="timesignature: "
        )

        self.language_prefix_tree = self._build_language_prefix_tree()

        self._load_genres_vocab()

        # Literal strings emitted verbatim in fixed-output FSM states.
        self.fixed_strings = {
            FSMState.THINK_TAG: "<think>",
            FSMState.NEWLINE_AFTER_THINK: "\n",
            FSMState.BPM_NAME: "bpm:",
            FSMState.CAPTION_NAME: "caption:",
            FSMState.DURATION_NAME: "duration:",
            FSMState.GENRES_NAME: "genres:",
            FSMState.KEYSCALE_NAME: "keyscale:",
            FSMState.LANGUAGE_NAME: "language:",
            FSMState.TIMESIG_NAME: "timesignature:",
            FSMState.THINK_END_TAG: "</think>",
        }

        # Transition table depends on the skip flags and user metadata.
        self._build_state_transitions()
|
|
|
|
|
def _get_next_field_state(self, current_field: str) -> Optional[FSMState]: |
|
|
""" |
|
|
Get the next field state. Always returns the next field's NAME state, |
|
|
even if the field is user-provided (we still need to generate the field name). |
|
|
|
|
|
Args: |
|
|
current_field: Current field name ("bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature") |
|
|
|
|
|
Returns: |
|
|
Next FSMState (NAME state of next field), or THINK_END_TAG if no more fields |
|
|
""" |
|
|
|
|
|
|
|
|
field_order = ["bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature"] |
|
|
field_to_state = { |
|
|
"bpm": FSMState.BPM_NAME, |
|
|
"caption": FSMState.CAPTION_NAME, |
|
|
"duration": FSMState.DURATION_NAME, |
|
|
"genres": FSMState.GENRES_NAME, |
|
|
"keyscale": FSMState.KEYSCALE_NAME, |
|
|
"language": FSMState.LANGUAGE_NAME, |
|
|
"timesignature": FSMState.TIMESIG_NAME, |
|
|
} |
|
|
|
|
|
try: |
|
|
current_idx = field_order.index(current_field) |
|
|
except ValueError: |
|
|
return FSMState.THINK_END_TAG |
|
|
|
|
|
|
|
|
for i in range(current_idx + 1, len(field_order)): |
|
|
field = field_order[i] |
|
|
|
|
|
|
|
|
if field == "genres" and self.skip_genres: |
|
|
continue |
|
|
if field == "caption" and self.skip_caption: |
|
|
continue |
|
|
if field == "language" and self.skip_language: |
|
|
continue |
|
|
|
|
|
|
|
|
return field_to_state[field] |
|
|
|
|
|
|
|
|
return FSMState.THINK_END_TAG |
|
|
|
|
|
def _build_state_transitions(self): |
|
|
"""Build state transition map based on user-provided metadata.""" |
|
|
self.next_state = { |
|
|
FSMState.THINK_TAG: FSMState.NEWLINE_AFTER_THINK, |
|
|
FSMState.NEWLINE_AFTER_THINK: FSMState.BPM_NAME, |
|
|
FSMState.THINK_END_TAG: FSMState.CODES_GENERATION, |
|
|
FSMState.CODES_GENERATION: FSMState.COMPLETED, |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.next_state[FSMState.BPM_NAME] = FSMState.BPM_VALUE |
|
|
self.next_state[FSMState.BPM_VALUE] = self._get_next_field_state("bpm") |
|
|
|
|
|
|
|
|
if not self.skip_caption: |
|
|
self.next_state[FSMState.CAPTION_NAME] = FSMState.CAPTION_VALUE |
|
|
self.next_state[FSMState.CAPTION_VALUE] = self._get_next_field_state("caption") |
|
|
|
|
|
|
|
|
self.next_state[FSMState.DURATION_NAME] = FSMState.DURATION_VALUE |
|
|
self.next_state[FSMState.DURATION_VALUE] = self._get_next_field_state("duration") |
|
|
|
|
|
|
|
|
if not self.skip_genres: |
|
|
self.next_state[FSMState.GENRES_NAME] = FSMState.GENRES_VALUE |
|
|
self.next_state[FSMState.GENRES_VALUE] = self._get_next_field_state("genres") |
|
|
|
|
|
|
|
|
self.next_state[FSMState.KEYSCALE_NAME] = FSMState.KEYSCALE_VALUE |
|
|
self.next_state[FSMState.KEYSCALE_VALUE] = self._get_next_field_state("keyscale") |
|
|
|
|
|
|
|
|
if not self.skip_language: |
|
|
self.next_state[FSMState.LANGUAGE_NAME] = FSMState.LANGUAGE_VALUE |
|
|
self.next_state[FSMState.LANGUAGE_VALUE] = self._get_next_field_state("language") |
|
|
|
|
|
|
|
|
self.next_state[FSMState.TIMESIG_NAME] = FSMState.TIMESIG_VALUE |
|
|
self.next_state[FSMState.TIMESIG_VALUE] = FSMState.THINK_END_TAG |
|
|
|
|
|
def set_skip_genres(self, skip: bool): |
|
|
"""Set whether to skip genres generation and rebuild state transitions.""" |
|
|
self.skip_genres = skip |
|
|
self._build_state_transitions() |
|
|
|
|
|
def set_skip_caption(self, skip: bool): |
|
|
"""Set whether to skip caption generation and rebuild state transitions.""" |
|
|
self.skip_caption = skip |
|
|
self._build_state_transitions() |
|
|
|
|
|
def set_skip_language(self, skip: bool): |
|
|
"""Set whether to skip language generation and rebuild state transitions.""" |
|
|
self.skip_language = skip |
|
|
self._build_state_transitions() |
|
|
|
|
|
@staticmethod |
|
|
def postprocess_caption(caption: str) -> str: |
|
|
""" |
|
|
Post-process caption to remove YAML multi-line formatting. |
|
|
Converts YAML-style multi-line text (with newlines and leading spaces) |
|
|
to a single-line string. |
|
|
|
|
|
Example: |
|
|
Input: "An emotional ballad.\\n The track opens with piano.\\n More text." |
|
|
Output: "An emotional ballad. The track opens with piano. More text." |
|
|
|
|
|
Args: |
|
|
caption: Raw caption text with possible YAML formatting |
|
|
|
|
|
Returns: |
|
|
Clean single-line caption |
|
|
""" |
|
|
if not caption: |
|
|
return caption |
|
|
|
|
|
|
|
|
lines = caption.split('\n') |
|
|
|
|
|
|
|
|
cleaned_lines = [] |
|
|
for line in lines: |
|
|
stripped = line.strip() |
|
|
if stripped: |
|
|
cleaned_lines.append(stripped) |
|
|
|
|
|
|
|
|
return ' '.join(cleaned_lines) |
|
|
|
|
|
def set_stop_at_reasoning(self, stop: bool): |
|
|
""" |
|
|
Set whether to stop generation after </think> tag. |
|
|
|
|
|
Args: |
|
|
stop: If True, generation will stop immediately after </think> tag is generated. |
|
|
If False, generation continues to codes generation phase. |
|
|
""" |
|
|
self.stop_at_reasoning = stop |
|
|
|
|
|
def set_generation_phase(self, phase: str): |
|
|
""" |
|
|
Set the generation phase. |
|
|
|
|
|
Args: |
|
|
phase: "cot" for CoT metadata generation, "codes" for audio codes generation, |
|
|
or "understand" for audio understanding (codes → metadata + lyrics). |
|
|
When phase is "codes" and the input prompt already contains </think>, |
|
|
the FSM will skip metadata generation and go directly to codes generation. |
|
|
When phase is "understand", generate CoT metadata then free-form lyrics. |
|
|
""" |
|
|
if phase not in ("cot", "codes", "understand"): |
|
|
raise ValueError(f"Invalid generation phase: {phase!r}. Must be 'cot', 'codes', or 'understand'") |
|
|
self.generation_phase = phase |
|
|
|
|
|
def set_user_metadata(self, metadata: Optional[Dict[str, Optional[str]]] = None): |
|
|
""" |
|
|
Set user-provided metadata fields. Fields that are provided will be used directly |
|
|
instead of generating. Fields that are None will be generated. |
|
|
|
|
|
Args: |
|
|
metadata: Dictionary with optional fields: |
|
|
- "bpm": Optional[str] - e.g., "120" |
|
|
- "caption": Optional[str] - e.g., "A melodic piano piece..." |
|
|
- "duration": Optional[str] - e.g., "234" |
|
|
- "keyscale": Optional[str] - e.g., "G major" |
|
|
- "language": Optional[str] - e.g., "en" |
|
|
- "timesignature": Optional[str] - e.g., "4" |
|
|
- "genres": Optional[str] - e.g., "Pop Rock" |
|
|
If None, clears all user-provided metadata. |
|
|
""" |
|
|
if metadata is None: |
|
|
metadata = {} |
|
|
|
|
|
|
|
|
for field in ["bpm", "caption", "duration", "keyscale", "language", "timesignature", "genres"]: |
|
|
if field in metadata: |
|
|
self.user_provided_metadata[field] = metadata[field] |
|
|
else: |
|
|
self.user_provided_metadata[field] = None |
|
|
|
|
|
|
|
|
self._build_state_transitions() |
|
|
|
|
|
if self.debug: |
|
|
provided_fields = [k for k, v in self.user_provided_metadata.items() if v is not None] |
|
|
if provided_fields: |
|
|
logger.debug(f"User provided metadata fields: {provided_fields}") |
|
|
else: |
|
|
logger.debug("No user-provided metadata, all fields will be generated") |
|
|
|
|
|
    def _precompute_tokens(self):
        """Pre-compute commonly used token IDs for efficiency.

        All single-character lookups take the LAST token of each encoding
        (``tokens[-1]``) — presumably so that any boundary token the
        tokenizer prepends is ignored (TODO confirm against the tokenizer
        in use).
        """
        # Single-digit tokens 0-9 (used for numeric field constraints).
        self.digit_tokens = {}
        for d in range(10):
            tokens = self.tokenizer.encode(str(d), add_special_tokens=False)
            if tokens:
                self.digit_tokens[d] = tokens[-1]

        # Newline token: field separator inside the <think> block.
        newline_tokens = self.tokenizer.encode("\n", add_special_tokens=False)
        self.newline_token = newline_tokens[-1] if newline_tokens else None

        # Note-letter tokens (keyscale values start with one of these).
        self.note_tokens = {}
        for note in KEYSCALE_NOTES:
            tokens = self.tokenizer.encode(note, add_special_tokens=False)
            if tokens:
                self.note_tokens[note] = tokens[-1]

        # Accidentals: ASCII and Unicode sharp variants.
        self.sharp_tokens = []
        for s in ["#", "♯"]:
            tokens = self.tokenizer.encode(s, add_special_tokens=False)
            if tokens:
                self.sharp_tokens.append(tokens[-1])

        # Accidentals: ASCII "b" and Unicode flat variants.
        self.flat_tokens = []
        for f in ["b", "♭"]:
            tokens = self.tokenizer.encode(f, add_special_tokens=False)
            if tokens:
                self.flat_tokens.append(tokens[-1])

        # Space token (separates note and mode, e.g. "G major").
        space_tokens = self.tokenizer.encode(" ", add_special_tokens=False)
        self.space_token = space_tokens[-1] if space_tokens else None

        # Tokens that can start "major"/"minor".
        # NOTE(review): ``prefix.lower() == "m"`` is True for BOTH "m" and
        # "M", so both case variants end up in BOTH lists. If the intent
        # was "m" -> minor only and "M" -> major only, this condition is
        # wrong — confirm against whatever consumes these lists.
        self.major_start_tokens = []
        self.minor_start_tokens = []
        for prefix in ["m", "M"]:
            tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
            if tokens:
                if prefix.lower() == "m":
                    self.minor_start_tokens.append(tokens[-1])
                self.major_start_tokens.append(tokens[-1])

        # Vocabulary size (len(tokenizer) includes added special tokens).
        self.vocab_size = len(self.tokenizer)

        # Comma token (list separator, e.g. between genres).
        comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
        self.comma_token = comma_tokens[-1] if comma_tokens else None

        # End-of-sequence token.
        self.eos_token_id = self.tokenizer.eos_token_id

        # Period token (captions must end with a period; see class docstring).
        period_tokens = self.tokenizer.encode(".", add_special_tokens=False)
        self.period_token = period_tokens[-1] if period_tokens else None

        # Backtick token (used to block code blocks during caption generation).
        backtick_tokens = self.tokenizer.encode("`", add_special_tokens=False)
        self.backtick_token = backtick_tokens[-1] if backtick_tokens else None

        # Allowed language codes (shared constant).
        self.valid_languages = VALID_LANGUAGES

        # All <|audio_code_N|> tokens with N in [0, MAX_AUDIO_CODE].
        self.audio_code_token_ids: Set[int] = set()
        self._precompute_audio_code_tokens()

        # Additive logit masks derived from the audio-code token set
        # (built by _build_audio_code_mask; None when no code tokens exist).
        self.audio_code_mask: Optional[torch.Tensor] = None
        self.non_audio_code_mask: Optional[torch.Tensor] = None
        self._build_audio_code_mask()

        # Copy so later mutation cannot affect the shared constant.
        self.valid_keyscales = VALID_KEYSCALES.copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def _precompute_audio_code_tokens(self):
        """
        Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>).
        These tokens should be blocked during caption generation.
        Only tokens with code values in range [0, MAX_AUDIO_CODE] are included.

        Note: this decodes every vocabulary entry once (O(vocab_size)), which
        is why it runs a single time at construction.
        """
        import re
        audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$')
        invalid_tokens_count = 0

        # Scan the whole vocabulary; tokens whose decode raises are skipped.
        for token_id in range(self.vocab_size):
            try:
                token_text = self.tokenizer.decode([token_id])
                match = audio_code_pattern.match(token_text)
                if match:
                    code_value = int(match.group(1))
                    if 0 <= code_value <= MAX_AUDIO_CODE:
                        self.audio_code_token_ids.add(token_id)
                    else:
                        # Token exists in the vocab but carries an
                        # out-of-range code value.
                        invalid_tokens_count += 1
                        if self.debug:
                            logger.debug(f"Skipping audio code token {token_id} with invalid code value {code_value} (max: {MAX_AUDIO_CODE})")
            except Exception:
                continue

        if invalid_tokens_count > 0:
            logger.debug(f"Found {invalid_tokens_count} audio code tokens with values outside valid range [0, {MAX_AUDIO_CODE}]")

        # An empty set means the codes phase cannot be constrained at all.
        if len(self.audio_code_token_ids) == 0:
            logger.warning(f"No valid audio code tokens found in vocabulary (range [0, {MAX_AUDIO_CODE}]). Code generation may fail.")
        elif self.debug:
            logger.debug(f"Found {len(self.audio_code_token_ids)} valid audio code tokens (range [0, {MAX_AUDIO_CODE}])")
|
|
|
|
|
def _extract_code_from_token(self, token_id: int) -> Optional[int]: |
|
|
""" |
|
|
Extract audio code value from a token ID. |
|
|
|
|
|
Args: |
|
|
token_id: Token ID to extract code value from |
|
|
|
|
|
Returns: |
|
|
Code value if token is a valid audio code token, None otherwise |
|
|
""" |
|
|
import re |
|
|
audio_code_pattern = re.compile(r'^<\|audio_code_(\d+)\|>$') |
|
|
|
|
|
try: |
|
|
token_text = self.tokenizer.decode([token_id]) |
|
|
match = audio_code_pattern.match(token_text) |
|
|
if match: |
|
|
return int(match.group(1)) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
return None |
|
|
|
|
|
def _build_audio_code_mask(self): |
|
|
""" |
|
|
Build a precomputed mask tensor for blocking audio code tokens. |
|
|
This mask can be added to scores in O(1) time instead of O(n) loop. |
|
|
|
|
|
The mask is [1, vocab_size] tensor with -inf at audio code token positions. |
|
|
|
|
|
Also builds the inverse mask (non_audio_code_mask) for CODES_GENERATION state, |
|
|
which blocks all non-audio-code tokens. |
|
|
""" |
|
|
if not self.audio_code_token_ids: |
|
|
self.audio_code_mask = None |
|
|
self.non_audio_code_mask = None |
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
mask = torch.zeros(1, self.vocab_size, dtype=torch.float32) |
|
|
|
|
|
|
|
|
audio_code_indices = list(self.audio_code_token_ids) |
|
|
|
|
|
|
|
|
mask[0, audio_code_indices] = float('-inf') |
|
|
|
|
|
self.audio_code_mask = mask |
|
|
|
|
|
|
|
|
|
|
|
inverse_mask = torch.full((1, self.vocab_size), float('-inf'), dtype=torch.float32) |
|
|
inverse_mask[0, audio_code_indices] = 0 |
|
|
|
|
|
|
|
|
if self.eos_token_id is not None: |
|
|
inverse_mask[0, self.eos_token_id] = 0 |
|
|
|
|
|
self.non_audio_code_mask = inverse_mask |
|
|
|
|
|
if self.debug: |
|
|
logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens") |
|
|
|
|
|
def _apply_whitelist_inplace(self, scores: torch.Tensor, allowed_tokens: List[int]) -> None: |
|
|
""" |
|
|
Apply whitelist constraint inplace: only allow specified tokens, block all others. |
|
|
|
|
|
This is more efficient than creating a mask tensor because: |
|
|
1. No memory allocation for mask |
|
|
2. No tensor addition operation |
|
|
|
|
|
Args: |
|
|
scores: [1, vocab_size] scores tensor to modify inplace |
|
|
allowed_tokens: List of token IDs to allow (all others will be set to -inf) |
|
|
""" |
|
|
if not allowed_tokens: |
|
|
|
|
|
scores.fill_(float('-inf')) |
|
|
return |
|
|
|
|
|
|
|
|
allowed_indices = torch.tensor(allowed_tokens, device=scores.device, dtype=torch.long) |
|
|
saved_values = scores[0, allowed_indices].clone() |
|
|
|
|
|
|
|
|
scores.fill_(float('-inf')) |
|
|
|
|
|
|
|
|
scores[0, allowed_indices] = saved_values |
|
|
|
|
|
    def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
        """
        Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.

        IMPORTANT: Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches.

        The tokenizer may merge the context's trailing space into the next token:
        - "keyscale: " tokenizes to [10563, 2246, 25, 220] -> ['keys', 'cale', ':', ' ']
        - "keyscale: G major" tokenizes to [10563, 2246, 25, 479, 3598] -> ['keys', 'cale', ':', ' G', ' major']
        The space ' ' (220) is merged into ' G' (479), so simple slicing cannot be used.

        Strategy:
        1. For each keyscale (e.g., "G major"), encode the FULL string "keyscale: G major"
        2. Match the spaceless context "keyscale:" at the front of that encoding
        3. Extract the value tokens after the context (e.g. [' G', ' major'])
        4. Register every token-sequence prefix with the set of tokens allowed next;
           the complete value may be followed by a newline.

        This ensures we get the exact tokenization that occurs during generation.
        """
        prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}

        # Spaceless context for matching; spaced context for tokenization
        # (the space may merge into the first value token).
        context_prefix_for_matching = "keyscale:"
        context_prefix_for_tokenization = "keyscale: "

        context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False)

        if self.debug:
            context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids]
            logger.debug(f"Context for matching 'keyscale:' tokenizes to {context_token_ids} -> {context_tokens_str}")

        for keyscale in self.valid_keyscales:
            # Tokenize the value together with its field context.
            full_text = context_prefix_for_tokenization + keyscale
            full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)

            # Locate where the context tokens end within the full encoding.
            context_end_idx = None

            if len(full_token_ids) >= len(context_token_ids):
                if full_token_ids[:len(context_token_ids)] == context_token_ids:
                    context_end_idx = len(context_token_ids)

            if context_end_idx is None:
                if self.debug:
                    logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
                continue

            # Tokens belonging to the keyscale value itself.
            keyscale_token_ids = full_token_ids[context_end_idx:]

            if not keyscale_token_ids:
                if self.debug:
                    logger.warning(f"No tokens extracted for keyscale '{keyscale}', skipping")
                continue

            # Sanity filter: the first value token must start with a note
            # letter A-G once any merged leading space is stripped.
            first_token_id = keyscale_token_ids[0]
            first_token_str = self.tokenizer.decode([first_token_id])

            first_char = first_token_str.lstrip()[0].upper() if first_token_str.lstrip() else ""
            if first_char not in "ABCDEFG":
                if self.debug:
                    logger.debug(f"Skipping keyscale '{keyscale}': first token is '{first_token_str}' (id={first_token_id}), not a note")
                continue

            # Register every prefix of the value (including the complete
            # value, whose only legal follower is the newline token).
            for i in range(len(keyscale_token_ids) + 1):
                token_prefix = tuple(keyscale_token_ids[:i])

                if token_prefix not in prefix_to_tokens:
                    prefix_to_tokens[token_prefix] = set()

                if i < len(keyscale_token_ids):
                    next_token_id = keyscale_token_ids[i]
                    prefix_to_tokens[token_prefix].add(next_token_id)
                else:
                    if self.newline_token:
                        prefix_to_tokens[token_prefix].add(self.newline_token)

        if self.debug:
            logger.debug(f"Built keyscale prefix tree with {len(prefix_to_tokens)} token sequence prefixes")

            empty_prefix = tuple()
            if empty_prefix in prefix_to_tokens:
                first_tokens = prefix_to_tokens[empty_prefix]
                decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)]
                logger.debug(f"First tokens allowed (empty prefix): {decoded_first}")

        return prefix_to_tokens
|
|
|
|
|
def _build_numeric_prefix_tree( |
|
|
self, |
|
|
valid_values: List[str], |
|
|
context_prefix_for_matching: str = "", |
|
|
context_prefix_for_tokenization: str = "" |
|
|
) -> Dict[Tuple[int, ...], Set[int]]: |
|
|
""" |
|
|
Build prefix tree for numeric field based on actual tokenization with context. |
|
|
|
|
|
IMPORTANT: Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches. |
|
|
|
|
|
Args: |
|
|
valid_values: List of valid numeric strings (e.g., ["30", "31", ..., "300"]) |
|
|
context_prefix_for_matching: Context string that state machine generates (e.g., "bpm:") - no space |
|
|
context_prefix_for_tokenization: Context string for tokenization (e.g., "bpm: ") - with space |
|
|
|
|
|
Returns: |
|
|
Dict mapping token ID sequence prefix -> set of allowed token IDs |
|
|
""" |
|
|
prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {} |
|
|
|
|
|
|
|
|
context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False) if context_prefix_for_matching else [] |
|
|
|
|
|
|
|
|
for value_str in valid_values: |
|
|
|
|
|
full_text = context_prefix_for_tokenization + value_str |
|
|
token_ids = self.tokenizer.encode(full_text, add_special_tokens=False) |
|
|
|
|
|
|
|
|
context_end_idx = None |
|
|
if len(token_ids) >= len(context_token_ids): |
|
|
if token_ids[:len(context_token_ids)] == context_token_ids: |
|
|
context_end_idx = len(context_token_ids) |
|
|
|
|
|
if context_end_idx is None: |
|
|
if self.debug: |
|
|
logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping") |
|
|
continue |
|
|
|
|
|
|
|
|
value_token_ids = token_ids[context_end_idx:] |
|
|
|
|
|
|
|
|
for i in range(len(value_token_ids) + 1): |
|
|
|
|
|
token_prefix = tuple(value_token_ids[:i]) |
|
|
|
|
|
if token_prefix not in prefix_to_tokens: |
|
|
prefix_to_tokens[token_prefix] = set() |
|
|
|
|
|
if i < len(value_token_ids): |
|
|
|
|
|
next_token_id = value_token_ids[i] |
|
|
prefix_to_tokens[token_prefix].add(next_token_id) |
|
|
else: |
|
|
|
|
|
if self.newline_token: |
|
|
prefix_to_tokens[token_prefix].add(self.newline_token) |
|
|
|
|
|
return prefix_to_tokens |
|
|
|
|
|
def _build_language_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]: |
|
|
""" |
|
|
Build language prefix to allowed tokens mapping based on ACTUAL tokenization. |
|
|
Similar to keyscale prefix tree but for language codes. |
|
|
|
|
|
Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches. |
|
|
""" |
|
|
prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {} |
|
|
|
|
|
context_prefix_for_matching = "language:" |
|
|
context_prefix_for_tokenization = "language: " |
|
|
|
|
|
context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False) |
|
|
|
|
|
if self.debug: |
|
|
context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids] |
|
|
logger.debug(f"Context for matching 'language:' tokenizes to {context_token_ids} -> {context_tokens_str}") |
|
|
|
|
|
for lang in self.valid_languages: |
|
|
full_text = context_prefix_for_tokenization + lang |
|
|
full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False) |
|
|
|
|
|
context_end_idx = None |
|
|
if len(full_token_ids) >= len(context_token_ids): |
|
|
if full_token_ids[:len(context_token_ids)] == context_token_ids: |
|
|
context_end_idx = len(context_token_ids) |
|
|
|
|
|
if context_end_idx is None: |
|
|
if self.debug: |
|
|
logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping") |
|
|
continue |
|
|
|
|
|
lang_token_ids = full_token_ids[context_end_idx:] |
|
|
|
|
|
if not lang_token_ids: |
|
|
if self.debug: |
|
|
logger.warning(f"No tokens extracted for language '{lang}', skipping") |
|
|
continue |
|
|
|
|
|
for i in range(len(lang_token_ids) + 1): |
|
|
token_prefix = tuple(lang_token_ids[:i]) |
|
|
|
|
|
if token_prefix not in prefix_to_tokens: |
|
|
prefix_to_tokens[token_prefix] = set() |
|
|
|
|
|
if i < len(lang_token_ids): |
|
|
next_token_id = lang_token_ids[i] |
|
|
prefix_to_tokens[token_prefix].add(next_token_id) |
|
|
else: |
|
|
if self.newline_token: |
|
|
prefix_to_tokens[token_prefix].add(self.newline_token) |
|
|
|
|
|
if self.debug: |
|
|
logger.debug(f"Built language prefix tree with {len(prefix_to_tokens)} token sequence prefixes") |
|
|
empty_prefix = tuple() |
|
|
if empty_prefix in prefix_to_tokens: |
|
|
first_tokens = prefix_to_tokens[empty_prefix] |
|
|
decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)] |
|
|
logger.debug(f"First tokens allowed for language (empty prefix): {decoded_first}") |
|
|
|
|
|
return prefix_to_tokens |
|
|
|
|
|
def diagnose_keyscale_prefix_tree(self): |
|
|
""" |
|
|
Diagnose the keyscale prefix tree to help debug generation bias. |
|
|
Call this method to print detailed information about allowed tokens at each prefix. |
|
|
""" |
|
|
print("=" * 60) |
|
|
print("KEYSCALE PREFIX TREE DIAGNOSIS") |
|
|
print("=" * 60) |
|
|
|
|
|
|
|
|
if "" in self.keyscale_prefix_tree: |
|
|
first_tokens = self.keyscale_prefix_tree[""] |
|
|
print(f"\n[Empty prefix] Allowed first tokens ({len(first_tokens)} total):") |
|
|
for t in sorted(first_tokens): |
|
|
decoded = self.tokenizer.decode([t]) |
|
|
print(f" Token {t}: {repr(decoded)}") |
|
|
else: |
|
|
print("\nWARNING: Empty prefix not in tree!") |
|
|
|
|
|
|
|
|
test_prefixes = ["A", "B", "C", "D", "E", "F", "G"] |
|
|
for prefix in test_prefixes: |
|
|
|
|
|
for test_key in [prefix, prefix + " "]: |
|
|
if test_key in self.keyscale_prefix_tree: |
|
|
tokens = self.keyscale_prefix_tree[test_key] |
|
|
print(f"\n[Prefix {repr(test_key)}] Allowed tokens ({len(tokens)}):") |
|
|
for t in sorted(tokens): |
|
|
decoded = self.tokenizer.decode([t]) |
|
|
print(f" Token {t}: {repr(decoded)}") |
|
|
|
|
|
|
|
|
print(f"\n[Valid keyscales] Total: {len(self.valid_keyscales)}") |
|
|
sample = sorted(list(self.valid_keyscales))[:10] |
|
|
for ks in sample: |
|
|
print(f" {repr(ks)}") |
|
|
|
|
|
print("=" * 60) |
|
|
|
|
|
|
|
|
def _load_genres_vocab(self): |
|
|
""" |
|
|
Load genres vocabulary from file. Supports hot reload by checking file mtime. |
|
|
File format: one genre per line, lines starting with # are comments. |
|
|
""" |
|
|
if not os.path.exists(self.genres_vocab_path): |
|
|
if self.debug: |
|
|
logger.debug(f"Genres vocab file not found: {self.genres_vocab_path}") |
|
|
return |
|
|
|
|
|
try: |
|
|
mtime = os.path.getmtime(self.genres_vocab_path) |
|
|
if mtime <= self.genres_vocab_mtime: |
|
|
return |
|
|
|
|
|
with open(self.genres_vocab_path, 'r', encoding='utf-8') as f: |
|
|
genres = [] |
|
|
for line in f: |
|
|
line = line.strip() |
|
|
if line and not line.startswith('#'): |
|
|
genres.append(line.lower()) |
|
|
|
|
|
self.genres_vocab = genres |
|
|
self.genres_vocab_mtime = mtime |
|
|
self._build_genres_trie() |
|
|
|
|
|
if self.debug: |
|
|
logger.debug(f"Loaded {len(self.genres_vocab)} genres from {self.genres_vocab_path}") |
|
|
except Exception as e: |
|
|
logger.warning(f"Failed to load genres vocab: {e}") |
|
|
|
|
|
def _build_genres_trie(self): |
|
|
""" |
|
|
Build a trie (prefix tree) from genres vocabulary for efficient prefix matching. |
|
|
Each node is a dict with: |
|
|
- '_end': True if this node represents a complete genre |
|
|
- other keys: next characters in the trie |
|
|
""" |
|
|
self.genres_trie = {} |
|
|
|
|
|
for genre in self.genres_vocab: |
|
|
node = self.genres_trie |
|
|
for char in genre: |
|
|
if char not in node: |
|
|
node[char] = {} |
|
|
node = node[char] |
|
|
node['_end'] = True |
|
|
|
|
|
if self.debug: |
|
|
logger.debug(f"Built genres trie with {len(self.genres_vocab)} entries") |
|
|
|
|
|
    def _extract_caption_genres(self, caption: str) -> None:
        """
        Extract genres from the user's caption that match entries in the vocabulary.
        This creates a smaller trie for faster and more relevant genre generation.

        Strategy (optimized - O(words * max_genre_len) instead of O(vocab_size)):
        1. Extract words/phrases from caption
        2. For each word, use trie to find all vocab entries that START with this word
        3. Build a separate trie from matched genres

        Side effects: sets self.caption_matched_genres and self.caption_genres_trie
        only when at least one genre matched; otherwise leaves them untouched so
        the full vocabulary trie stays in effect.
        """
        if not caption or not self.genres_vocab:
            return

        caption_lower = caption.lower()
        matched_genres = set()

        # Split the caption on common separators, keep words of length >= 2.
        import re
        words = re.split(r'[,\s\-_/\\|]+', caption_lower)
        words = [w.strip() for w in words if w.strip() and len(w.strip()) >= 2]

        for word in words:
            # Collect every vocab genre that starts with this caption word.
            node = self._get_genres_trie_node(word)
            if node is not None:
                self._collect_complete_genres(node, word, matched_genres)

        # Also accept exact word == genre matches (already covered by the
        # prefix walk above, but kept as an explicit, cheap second pass).
        genres_set = set(self.genres_vocab)
        for word in words:
            if word in genres_set:
                matched_genres.add(word)

        if not matched_genres:
            if self.debug:
                logger.debug(f"No genres matched in caption, using full vocab")
            return

        # Build a dedicated character trie holding only the matched genres.
        self.caption_matched_genres = list(matched_genres)
        self.caption_genres_trie = {}

        for genre in matched_genres:
            node = self.caption_genres_trie
            for char in genre:
                if char not in node:
                    node[char] = {}
                node = node[char]
            node['_end'] = True

        if self.debug:
            logger.debug(f"Matched {len(matched_genres)} genres from caption: {list(matched_genres)[:5]}...")
|
|
|
|
|
def _collect_complete_genres(self, node: Dict, prefix: str, result: set, max_depth: int = 50): |
|
|
""" |
|
|
Recursively collect all complete genres under a trie node. |
|
|
Limited depth to avoid too many matches. |
|
|
""" |
|
|
if max_depth <= 0: |
|
|
return |
|
|
|
|
|
if node.get('_end', False): |
|
|
result.add(prefix) |
|
|
|
|
|
|
|
|
if len(result) >= 100: |
|
|
return |
|
|
|
|
|
for char, child_node in node.items(): |
|
|
if char not in ('_end', '_tokens'): |
|
|
self._collect_complete_genres(child_node, prefix + char, result, max_depth - 1) |
|
|
|
|
|
def _precompute_char_token_mapping(self): |
|
|
""" |
|
|
Precompute mapping from characters to token IDs and token decoded texts. |
|
|
This allows O(1) lookup instead of calling tokenizer.encode()/decode() at runtime. |
|
|
|
|
|
Time complexity: O(vocab_size) - runs once during initialization |
|
|
|
|
|
Note: Many subword tokenizers (like Qwen) add space prefixes to tokens. |
|
|
We need to handle both the raw first char and the first non-space char. |
|
|
""" |
|
|
self._char_to_tokens: Dict[str, set] = {} |
|
|
self._token_to_text: Dict[int, str] = {} |
|
|
|
|
|
|
|
|
for token_id in range(self.vocab_size): |
|
|
try: |
|
|
text = self.tokenizer.decode([token_id]) |
|
|
|
|
|
if not text: |
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
text_lower = text.lower() |
|
|
if text_lower.strip(): |
|
|
normalized_text = text_lower.rstrip() |
|
|
else: |
|
|
normalized_text = " " |
|
|
self._token_to_text[token_id] = normalized_text |
|
|
|
|
|
|
|
|
first_char = text[0].lower() |
|
|
if first_char not in self._char_to_tokens: |
|
|
self._char_to_tokens[first_char] = set() |
|
|
self._char_to_tokens[first_char].add(token_id) |
|
|
|
|
|
|
|
|
|
|
|
stripped_text = text.lstrip() |
|
|
if stripped_text and stripped_text != text: |
|
|
first_nonspace_char = stripped_text[0].lower() |
|
|
if first_nonspace_char not in self._char_to_tokens: |
|
|
self._char_to_tokens[first_nonspace_char] = set() |
|
|
self._char_to_tokens[first_nonspace_char].add(token_id) |
|
|
|
|
|
except Exception: |
|
|
continue |
|
|
|
|
|
if self.debug: |
|
|
logger.debug(f"Precomputed char->token mapping for {len(self._char_to_tokens)} unique characters") |
|
|
|
|
|
def _try_reload_genres_vocab(self): |
|
|
"""Check if genres vocab file has been updated and reload if necessary.""" |
|
|
if not os.path.exists(self.genres_vocab_path): |
|
|
return |
|
|
|
|
|
try: |
|
|
mtime = os.path.getmtime(self.genres_vocab_path) |
|
|
if mtime > self.genres_vocab_mtime: |
|
|
self._load_genres_vocab() |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
def _get_genres_trie_node(self, prefix: str) -> Optional[Dict]: |
|
|
""" |
|
|
Get the trie node for a given prefix. |
|
|
Returns None if the prefix is not valid (no genres start with this prefix). |
|
|
""" |
|
|
node = self.genres_trie |
|
|
for char in prefix.lower(): |
|
|
if char not in node: |
|
|
return None |
|
|
node = node[char] |
|
|
return node |
|
|
|
|
|
def _is_complete_genre(self, text: str) -> bool: |
|
|
"""Check if the given text is a complete genre in the vocabulary.""" |
|
|
node = self._get_genres_trie_node(text.strip()) |
|
|
return node is not None and node.get('_end', False) |
|
|
|
|
|
def _get_trie_node_from_trie(self, trie: Dict, prefix: str) -> Optional[Dict]: |
|
|
"""Get a trie node from a specific trie (helper for caption vs full trie).""" |
|
|
node = trie |
|
|
for char in prefix.lower(): |
|
|
if char not in node: |
|
|
return None |
|
|
node = node[char] |
|
|
return node |
|
|
|
|
|
def _get_allowed_genres_tokens(self) -> List[int]: |
|
|
""" |
|
|
Get allowed tokens for genres field based on trie matching. |
|
|
|
|
|
The entire genres string (including commas) must match a complete entry in the vocab. |
|
|
For example, if vocab contains "pop, rock, jazz", the generated string must exactly |
|
|
match that entry - we don't treat commas as separators for individual genres. |
|
|
|
|
|
Strategy: |
|
|
1. If caption-matched genres exist, use that smaller trie first (faster + more relevant) |
|
|
2. If no caption matches or prefix not in caption trie, fallback to full vocab trie |
|
|
3. Get valid next characters from current trie node |
|
|
4. For each candidate token, verify the full decoded text forms a valid trie prefix |
|
|
""" |
|
|
if not self.genres_vocab: |
|
|
|
|
|
return [] |
|
|
|
|
|
|
|
|
accumulated = self.accumulated_value.lower() |
|
|
current_genre_prefix = accumulated.strip() |
|
|
|
|
|
|
|
|
use_caption_trie = False |
|
|
current_node = None |
|
|
|
|
|
|
|
|
if self.caption_genres_trie: |
|
|
if current_genre_prefix == "": |
|
|
current_node = self.caption_genres_trie |
|
|
use_caption_trie = True |
|
|
else: |
|
|
current_node = self._get_trie_node_from_trie(self.caption_genres_trie, current_genre_prefix) |
|
|
if current_node is not None: |
|
|
use_caption_trie = True |
|
|
|
|
|
|
|
|
if current_node is None: |
|
|
if current_genre_prefix == "": |
|
|
current_node = self.genres_trie |
|
|
else: |
|
|
current_node = self._get_genres_trie_node(current_genre_prefix) |
|
|
|
|
|
if current_node is None: |
|
|
|
|
|
if self.newline_token: |
|
|
return [self.newline_token] |
|
|
return [] |
|
|
|
|
|
|
|
|
valid_next_chars = set(k for k in current_node.keys() if k not in ('_end', '_tokens')) |
|
|
|
|
|
|
|
|
is_complete = current_node.get('_end', False) |
|
|
|
|
|
if not valid_next_chars: |
|
|
|
|
|
allowed = set() |
|
|
if is_complete and self.newline_token: |
|
|
allowed.add(self.newline_token) |
|
|
return list(allowed) |
|
|
|
|
|
|
|
|
candidate_tokens = set() |
|
|
for char in valid_next_chars: |
|
|
if char in self._char_to_tokens: |
|
|
candidate_tokens.update(self._char_to_tokens[char]) |
|
|
|
|
|
|
|
|
active_trie = self.caption_genres_trie if use_caption_trie else self.genres_trie |
|
|
|
|
|
|
|
|
allowed = set() |
|
|
for token_id in candidate_tokens: |
|
|
|
|
|
decoded_normalized = self._token_to_text.get(token_id, "") |
|
|
|
|
|
if not decoded_normalized or not decoded_normalized.strip(): |
|
|
|
|
|
if ' ' in valid_next_chars or ',' in valid_next_chars: |
|
|
allowed.add(token_id) |
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
if decoded_normalized.startswith(' ') or decoded_normalized.startswith(','): |
|
|
|
|
|
new_prefix = current_genre_prefix + decoded_normalized |
|
|
else: |
|
|
new_prefix = current_genre_prefix + decoded_normalized |
|
|
|
|
|
|
|
|
new_node = self._get_trie_node_from_trie(active_trie, new_prefix) |
|
|
if new_node is not None: |
|
|
allowed.add(token_id) |
|
|
|
|
|
|
|
|
if is_complete and self.newline_token: |
|
|
allowed.add(self.newline_token) |
|
|
|
|
|
return list(allowed) |
|
|
|
|
|
    def reset(self):
        """Reset the processor state for a new generation."""
        # Return the FSM to its initial state (expecting the <think> tag).
        self.state = FSMState.THINK_TAG
        # Character offset within the fixed string currently being emitted.
        self.position_in_state = 0
        # Decoded text accumulated for the field currently being generated.
        self.accumulated_value = ""
        # Token IDs accumulated for the current field (prefix-tree lookup keys).
        self.accumulated_token_ids = []
        # Number of audio codes emitted so far in the codes phase.
        self.codes_count = 0
        # Pending forced tokens for a user-provided field value, and its name.
        self.user_field_token_queue = []
        self.current_user_field = None
        # Caption-specific bookkeeping flags and counters.
        self.caption_after_newline = False
        self.caption_token_count = 0
        self.caption_ending = False
        self.pending_field_name = ""
|
|
|
|
|
def set_target_duration(self, duration: Optional[float]): |
|
|
""" |
|
|
Set the target duration for codes generation. |
|
|
|
|
|
Args: |
|
|
duration: Target duration in seconds. If None, no duration constraint is applied. |
|
|
5 codes = 1 second, so target_codes = duration * 5. |
|
|
""" |
|
|
self.target_duration = duration |
|
|
if duration is not None and duration > 0: |
|
|
self.target_codes = int(duration * 5) |
|
|
if self.debug: |
|
|
logger.debug(f"Set target duration: {duration}s -> {self.target_codes} codes") |
|
|
else: |
|
|
self.target_codes = None |
|
|
if self.debug: |
|
|
logger.debug("Target duration cleared, no duration constraint") |
|
|
|
|
|
def set_max_duration(self, max_duration: int): |
|
|
""" |
|
|
Dynamically update the maximum allowed duration for constrained decoding. |
|
|
|
|
|
This method should be called when GPU configuration changes (e.g., LM initialization state changes). |
|
|
It rebuilds the duration prefix tree to constrain duration values to the new maximum. |
|
|
|
|
|
Args: |
|
|
max_duration: Maximum duration in seconds (e.g., 120 for 2 minutes, 360 for 6 minutes) |
|
|
""" |
|
|
if max_duration == self.max_duration: |
|
|
return |
|
|
|
|
|
old_max = self.max_duration |
|
|
self.max_duration = max_duration |
|
|
|
|
|
|
|
|
self.field_specs["duration"]["max"] = max_duration |
|
|
|
|
|
|
|
|
self.valid_duration_values = [str(v) for v in range(self.field_specs["duration"]["min"], self.field_specs["duration"]["max"] + 1)] |
|
|
|
|
|
|
|
|
self.duration_prefix_tree = self._build_numeric_prefix_tree( |
|
|
self.valid_duration_values, |
|
|
context_prefix_for_matching="duration:", |
|
|
context_prefix_for_tokenization="duration: " |
|
|
) |
|
|
|
|
|
if self.debug: |
|
|
logger.debug(f"Updated max duration: {old_max}s -> {max_duration}s, rebuilt prefix tree with {len(self.valid_duration_values)} values") |
|
|
|
|
|
def _get_allowed_tokens_for_fixed_string(self, fixed_str: str) -> List[int]: |
|
|
""" |
|
|
Get the token IDs that can continue the fixed string from current position. |
|
|
Returns list of allowed token IDs. |
|
|
|
|
|
Strategy: Find the longest prefix that encodes to a single token, and return that token. |
|
|
This ensures we generate by tokens, not character-by-character. |
|
|
""" |
|
|
remaining = fixed_str[self.position_in_state:] |
|
|
if not remaining: |
|
|
return [] |
|
|
|
|
|
if self.debug: |
|
|
logger.debug(f"_get_allowed_tokens_for_fixed_string: fixed_str={repr(fixed_str)}, position_in_state={self.position_in_state}, remaining={repr(remaining)}") |
|
|
|
|
|
|
|
|
|
|
|
best_token = None |
|
|
best_prefix_len = 0 |
|
|
|
|
|
|
|
|
for end in range(len(remaining), 0, -1): |
|
|
prefix = remaining[:end] |
|
|
tokens = self.tokenizer.encode(prefix, add_special_tokens=False) |
|
|
if tokens and len(tokens) == 1: |
|
|
|
|
|
|
|
|
best_token = tokens[0] |
|
|
best_prefix_len = end |
|
|
if self.debug: |
|
|
logger.debug(f"Found single-token match: prefix={repr(prefix)}, token_id={best_token}, token_text={repr(self.tokenizer.decode([best_token]))}") |
|
|
break |
|
|
|
|
|
|
|
|
if best_token is not None: |
|
|
return [best_token] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
allowed_tokens = {} |
|
|
for end in range(1, min(len(remaining) + 1, 20)): |
|
|
prefix = remaining[:end] |
|
|
tokens = self.tokenizer.encode(prefix, add_special_tokens=False) |
|
|
if tokens: |
|
|
first_token = tokens[0] |
|
|
|
|
|
decoded_token = self.tokenizer.decode([first_token]) |
|
|
|
|
|
normalized_prefix = prefix.lstrip().lower() |
|
|
normalized_decoded = decoded_token.lstrip().lower() |
|
|
|
|
|
|
|
|
if normalized_decoded.startswith(normalized_prefix) or normalized_prefix.startswith(normalized_decoded): |
|
|
|
|
|
if first_token not in allowed_tokens or end > allowed_tokens[first_token]: |
|
|
allowed_tokens[first_token] = end |
|
|
|
|
|
|
|
|
|
|
|
sorted_tokens = sorted(allowed_tokens.items(), key=lambda x: x[1], reverse=True) |
|
|
result = [token for token, _ in sorted_tokens] if sorted_tokens else [] |
|
|
|
|
|
if self.debug: |
|
|
logger.debug(f"Fallback: returning {len(result)} tokens: {[(t, repr(self.tokenizer.decode([t]))) for t in result[:5]]}") |
|
|
if result: |
|
|
logger.debug(f"Fixed string: {repr(fixed_str)}, position: {self.position_in_state}, remaining: {repr(remaining)}") |
|
|
|
|
|
return result |
|
|
|
|
|
def _get_allowed_digit_tokens(self, min_val: int, max_val: int) -> List[int]: |
|
|
""" |
|
|
Get allowed digit tokens based on accumulated value and range constraints. |
|
|
Uses early-blocking to prevent out-of-range values. |
|
|
""" |
|
|
if not self.accumulated_value: |
|
|
|
|
|
allowed_digits = set() |
|
|
for v in range(min_val, max_val + 1): |
|
|
allowed_digits.add(int(str(v)[0])) |
|
|
return [self.digit_tokens[d] for d in allowed_digits if d in self.digit_tokens] |
|
|
|
|
|
current = int(self.accumulated_value) |
|
|
allowed = [] |
|
|
|
|
|
for d in range(10): |
|
|
new_value = int(self.accumulated_value + str(d)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if new_value > max_val: |
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if new_value >= min_val: |
|
|
allowed.append(d) |
|
|
elif new_value * 10 <= max_val: |
|
|
|
|
|
allowed.append(d) |
|
|
|
|
|
return [self.digit_tokens[d] for d in allowed if d in self.digit_tokens] |
|
|
|
|
|
def _get_allowed_numeric_tokens(self, prefix_tree: Dict[Tuple[int, ...], Set[int]]) -> List[int]: |
|
|
""" |
|
|
Get allowed tokens for numeric field using the precomputed prefix tree. |
|
|
|
|
|
IMPORTANT: Uses token ID sequence as key (not string) to avoid tokenization mismatches. |
|
|
|
|
|
Args: |
|
|
prefix_tree: Precomputed prefix tree mapping token ID sequence -> set of allowed token IDs |
|
|
|
|
|
Returns: |
|
|
List of allowed token IDs for current accumulated_token_ids |
|
|
""" |
|
|
token_prefix = tuple(self.accumulated_token_ids) |
|
|
|
|
|
if token_prefix in prefix_tree: |
|
|
return list(prefix_tree[token_prefix]) |
|
|
|
|
|
|
|
|
|
|
|
return [] |
|
|
|
|
|
def _should_end_numeric_field(self, logits: torch.Tensor, min_val: int, max_val: int) -> bool: |
|
|
""" |
|
|
Determine if we should end the current numeric field. |
|
|
Returns True if P(newline) > P(any valid digit) AND current value is valid. |
|
|
""" |
|
|
if not self.accumulated_value: |
|
|
return False |
|
|
|
|
|
current = int(self.accumulated_value) |
|
|
if current < min_val or current > max_val: |
|
|
return False |
|
|
|
|
|
|
|
|
probs = torch.softmax(logits, dim=-1) |
|
|
|
|
|
newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0 |
|
|
|
|
|
|
|
|
allowed_digits = self._get_allowed_digit_tokens(min_val, max_val) |
|
|
if not allowed_digits: |
|
|
return True |
|
|
|
|
|
max_digit_prob = max(probs[0, t].item() for t in allowed_digits) |
|
|
|
|
|
if self.debug: |
|
|
logger.debug(f"Numeric field decision: newline_prob={newline_prob:.4f}, max_digit_prob={max_digit_prob:.4f}") |
|
|
|
|
|
return newline_prob > max_digit_prob |
|
|
|
|
|
|
|
|
def _should_end_text_field(self, logits: torch.Tensor) -> bool: |
|
|
""" |
|
|
Determine if we should end a text field (genres). |
|
|
Returns True if P(newline) > P(any other token) AND we have some content. |
|
|
""" |
|
|
if not self.accumulated_value.strip(): |
|
|
return False |
|
|
|
|
|
probs = torch.softmax(logits, dim=-1) |
|
|
newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0 |
|
|
|
|
|
|
|
|
masked_probs = probs.clone() |
|
|
if self.newline_token: |
|
|
masked_probs[0, self.newline_token] = 0 |
|
|
max_other_prob = masked_probs[0].max().item() |
|
|
|
|
|
return newline_prob > max_other_prob |
|
|
|
|
|
def _get_allowed_keyscale_tokens(self) -> List[int]: |
|
|
""" |
|
|
Get allowed tokens for keyscale field using the precomputed prefix tree. |
|
|
Uses token ID sequence as key (not string) to avoid tokenization mismatches. |
|
|
""" |
|
|
|
|
|
token_prefix = tuple(self.accumulated_token_ids) |
|
|
|
|
|
if token_prefix in self.keyscale_prefix_tree: |
|
|
return list(self.keyscale_prefix_tree[token_prefix]) |
|
|
|
|
|
|
|
|
|
|
|
return [] |
|
|
|
|
|
def _is_keyscale_complete(self) -> bool: |
|
|
""" |
|
|
Check if keyscale value is complete and valid. |
|
|
Uses token ID sequence to check if current prefix allows newline. |
|
|
""" |
|
|
token_prefix = tuple(self.accumulated_token_ids) |
|
|
|
|
|
if token_prefix in self.keyscale_prefix_tree: |
|
|
return self.newline_token in self.keyscale_prefix_tree[token_prefix] |
|
|
return False |
|
|
|
|
|
def _get_allowed_language_tokens(self) -> List[int]: |
|
|
""" |
|
|
Get allowed tokens for language field using the precomputed prefix tree. |
|
|
Uses token ID sequence as key (not string) to avoid tokenization mismatches. |
|
|
Similar to keyscale. |
|
|
""" |
|
|
token_prefix = tuple(self.accumulated_token_ids) |
|
|
|
|
|
if token_prefix in self.language_prefix_tree: |
|
|
return list(self.language_prefix_tree[token_prefix]) |
|
|
|
|
|
|
|
|
return [] |
|
|
|
|
|
def _get_allowed_timesig_tokens(self) -> List[int]: |
|
|
""" |
|
|
Get allowed tokens for timesignature field using the precomputed prefix tree. |
|
|
Uses token ID sequence as key (not string) to avoid tokenization mismatches. |
|
|
""" |
|
|
token_prefix = tuple(self.accumulated_token_ids) |
|
|
|
|
|
if token_prefix in self.timesig_prefix_tree: |
|
|
return list(self.timesig_prefix_tree[token_prefix]) |
|
|
|
|
|
|
|
|
|
|
|
return [] |
|
|
|
|
|
    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Apply constrained decoding by modifying logits.

        Dispatch order: disabled -> passthrough; COMPLETED -> optional audio
        mask; codes phase -> audio-code masking and EOS control; otherwise
        the per-sequence FSM. Temperature scaling is applied on every path.

        Args:
            input_ids: [batch_size, seq_len] input token IDs
            scores: [batch_size, vocab_size] logits for next token

        Returns:
            Modified scores with invalid tokens masked to -inf and temperature scaling applied
        """
        if not self.enabled:
            # Constraints off: only temperature scaling applies.
            return self._apply_temperature_scaling(scores)

        if self.state == FSMState.COMPLETED:
            # In understand mode, keep restricting to audio-code tokens even
            # after the FSM has finished.
            if self.generation_phase == "understand" and self.audio_code_mask is not None:
                # Move/cast the cached mask lazily so it matches the scores tensor.
                if self.audio_code_mask.device != scores.device or self.audio_code_mask.dtype != scores.dtype:
                    self.audio_code_mask = self.audio_code_mask.to(device=scores.device, dtype=scores.dtype)
                scores = scores + self.audio_code_mask
            return self._apply_temperature_scaling(scores)

        # Codes phase may start mid-stream: once </think> appears in the
        # prompt, jump straight to audio-code generation.
        if self.generation_phase == "codes" and self.state == FSMState.THINK_TAG:
            if self._input_contains_think_end_tag(input_ids):
                self.state = FSMState.CODES_GENERATION
                self.codes_count = 0
                if self.debug:
                    logger.debug("Codes phase: detected </think> in input, skipping to CODES_GENERATION")

        if self.state == FSMState.CODES_GENERATION:
            # Restrict output to audio-code tokens (mask holds -inf elsewhere).
            if self.non_audio_code_mask is not None:
                if self.non_audio_code_mask.device != scores.device or self.non_audio_code_mask.dtype != scores.dtype:
                    self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
                scores = scores + self.non_audio_code_mask

            # Duration constraint: block EOS until enough codes were emitted,
            # then force EOS once the target count is reached.
            if self.target_codes is not None and self.eos_token_id is not None:
                if self.codes_count < self.target_codes:
                    scores[:, self.eos_token_id] = float('-inf')
                    if self.debug:
                        logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
                else:
                    # Preserve the EOS logit, -inf everything else.
                    eos_scores = scores[:, self.eos_token_id].clone()
                    scores.fill_(float('-inf'))
                    scores[:, self.eos_token_id] = eos_scores
                    if self.debug:
                        logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
            return self._apply_temperature_scaling(scores)

        batch_size = scores.shape[0]

        # Metadata FSM: process each sequence independently, writing the
        # constrained row back in place.
        for b in range(batch_size):
            result = self._process_single_sequence(input_ids[b], scores[b:b+1])
            scores[b] = result[0]

        return self._apply_temperature_scaling(scores)
|
|
|
|
|
def _input_contains_think_end_tag(self, input_ids: torch.LongTensor) -> bool: |
|
|
""" |
|
|
Check if input contains the </think> closing tag. |
|
|
|
|
|
Args: |
|
|
input_ids: [batch_size, seq_len] input token IDs |
|
|
|
|
|
Returns: |
|
|
True if </think> is found in the input (any sequence in batch) |
|
|
""" |
|
|
|
|
|
think_end_tokens = self.tokenizer.encode("</think>", add_special_tokens=False) |
|
|
if not think_end_tokens: |
|
|
return False |
|
|
|
|
|
|
|
|
for b in range(input_ids.shape[0]): |
|
|
seq = input_ids[b].tolist() |
|
|
|
|
|
for i in range(len(seq) - len(think_end_tokens) + 1): |
|
|
if seq[i:i+len(think_end_tokens)] == think_end_tokens: |
|
|
return True |
|
|
|
|
|
return False |
|
|
|
|
|
def _apply_temperature_scaling(self, scores: torch.FloatTensor) -> torch.FloatTensor: |
|
|
""" |
|
|
Apply temperature scaling based on current generation phase. |
|
|
|
|
|
Temperature scaling: logits = logits / temperature |
|
|
- Lower temperature (< 1.0) makes distribution sharper (more deterministic) |
|
|
- Higher temperature (> 1.0) makes distribution flatter (more diverse) |
|
|
|
|
|
Args: |
|
|
scores: [batch_size, vocab_size] logits |
|
|
|
|
|
Returns: |
|
|
Temperature-scaled logits |
|
|
""" |
|
|
|
|
|
if self.state == FSMState.CODES_GENERATION or self.state == FSMState.COMPLETED: |
|
|
temperature = self.codes_temperature |
|
|
else: |
|
|
temperature = self.metadata_temperature |
|
|
|
|
|
|
|
|
if temperature is None: |
|
|
return scores |
|
|
|
|
|
|
|
|
if temperature <= 0: |
|
|
temperature = 1e-6 |
|
|
|
|
|
|
|
|
return scores / temperature |
|
|
|
|
|
def _get_user_provided_field_tokens(self, field_name: str) -> Optional[List[int]]: |
|
|
""" |
|
|
Get token sequence for a user-provided field (field_name + value + newline). |
|
|
Uses the same tokenization logic as prefix tree building. |
|
|
|
|
|
Args: |
|
|
field_name: Field name ("bpm", "caption", "duration", "keyscale", "language", "timesignature") |
|
|
|
|
|
Returns: |
|
|
List of token IDs for the complete field, or None if field is not provided |
|
|
""" |
|
|
value = self.user_provided_metadata.get(field_name) |
|
|
if value is None: |
|
|
return None |
|
|
|
|
|
|
|
|
field_to_prefix = { |
|
|
"bpm": "bpm: ", |
|
|
"caption": "caption: ", |
|
|
"duration": "duration: ", |
|
|
"keyscale": "keyscale: ", |
|
|
"language": "language: ", |
|
|
"timesignature": "timesignature: ", |
|
|
"genres": "genres: ", |
|
|
} |
|
|
prefix = field_to_prefix[field_name] |
|
|
full_text = f"{prefix}{value}\n" |
|
|
|
|
|
|
|
|
tokens = self.tokenizer.encode(full_text, add_special_tokens=False) |
|
|
|
|
|
|
|
|
|
|
|
prefix_for_matching = field_name + ":" |
|
|
prefix_tokens = self.tokenizer.encode(prefix_for_matching, add_special_tokens=False) |
|
|
|
|
|
|
|
|
if len(tokens) >= len(prefix_tokens) and tokens[:len(prefix_tokens)] == prefix_tokens: |
|
|
|
|
|
return tokens[len(prefix_tokens):] |
|
|
else: |
|
|
|
|
|
if self.debug: |
|
|
logger.warning(f"Could not match prefix tokens for field {field_name}, using all tokens") |
|
|
return tokens |
|
|
|
|
|
def _process_single_sequence( |
|
|
self, |
|
|
input_ids: torch.LongTensor, |
|
|
scores: torch.FloatTensor, |
|
|
) -> torch.FloatTensor: |
|
|
"""Process a single sequence and return modified scores (inplace when possible).""" |
|
|
|
|
|
|
|
|
|
|
|
if self.user_field_token_queue: |
|
|
next_token = self.user_field_token_queue[0] |
|
|
self._apply_whitelist_inplace(scores, [next_token]) |
|
|
return scores |
|
|
|
|
|
if self.state in self.fixed_strings: |
|
|
|
|
|
fixed_str = self.fixed_strings[self.state] |
|
|
allowed = self._get_allowed_tokens_for_fixed_string(fixed_str) |
|
|
|
|
|
if allowed: |
|
|
|
|
|
|
|
|
if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning: |
|
|
|
|
|
remaining_chars = len(fixed_str) - self.position_in_state |
|
|
|
|
|
if remaining_chars <= 10: |
|
|
|
|
|
if self.eos_token_id is not None: |
|
|
self._apply_whitelist_inplace(scores, [self.eos_token_id]) |
|
|
if self.debug: |
|
|
logger.debug(f"stop_at_reasoning=True: forcing EOS near end of </think> tag (remaining: {remaining_chars} chars)") |
|
|
return scores |
|
|
|
|
|
|
|
|
self._apply_whitelist_inplace(scores, allowed) |
|
|
else: |
|
|
|
|
|
|
|
|
|
|
|
if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning: |
|
|
|
|
|
if self.eos_token_id is not None: |
|
|
self._apply_whitelist_inplace(scores, [self.eos_token_id]) |
|
|
if self.debug: |
|
|
logger.debug(f"stop_at_reasoning=True: forcing EOS after completing </think> tag") |
|
|
return scores |
|
|
|
|
|
old_state = self.state |
|
|
self._transition_to_next_state() |
|
|
|
|
|
if self.state in self.fixed_strings: |
|
|
|
|
|
if self.debug: |
|
|
logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion") |
|
|
return scores |
|
|
|
|
|
scores.zero_() |
|
|
return self._process_single_sequence(input_ids, scores) |
|
|
|
|
|
elif self.state == FSMState.BPM_VALUE: |
|
|
|
|
|
if self.user_provided_metadata["bpm"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids: |
|
|
|
|
|
value = self.user_provided_metadata["bpm"] |
|
|
|
|
|
value_text = f" {value}\n" |
|
|
value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False) |
|
|
if value_tokens: |
|
|
self.user_field_token_queue = value_tokens |
|
|
self.current_user_field = "bpm" |
|
|
|
|
|
self._apply_whitelist_inplace(scores, [value_tokens[0]]) |
|
|
return scores |
|
|
|
|
|
|
|
|
allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree) |
|
|
|
|
|
|
|
|
token_prefix = tuple(self.accumulated_token_ids) |
|
|
if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]: |
|
|
allowed = allowed + [self.newline_token] |
|
|
|
|
|
self._apply_whitelist_inplace(scores, allowed) |
|
|
|
|
|
elif self.state == FSMState.CAPTION_VALUE: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.user_provided_metadata["caption"] is not None and not self.user_field_token_queue and not self.accumulated_value: |
|
|
|
|
|
value = self.user_provided_metadata["caption"] |
|
|
value_text = f" {value}\n" |
|
|
value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False) |
|
|
if value_tokens: |
|
|
self.user_field_token_queue = value_tokens |
|
|
self.current_user_field = "caption" |
|
|
|
|
|
self._apply_whitelist_inplace(scores, [value_tokens[0]]) |
|
|
return scores |
|
|
|
|
|
|
|
|
if self.caption_after_newline: |
|
|
|
|
|
top_token_id = torch.argmax(scores[0]).item() |
|
|
top_token_text = self.tokenizer.decode([top_token_id]) |
|
|
|
|
|
|
|
|
if len(top_token_text) > 0 and top_token_text[0] not in ' \t': |
|
|
|
|
|
|
|
|
|
|
|
self.caption_after_newline = False |
|
|
self.caption_ending = True |
|
|
self.pending_field_name = "" |
|
|
|
|
|
return scores |
|
|
else: |
|
|
|
|
|
self.caption_after_newline = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.caption_ending: |
|
|
|
|
|
|
|
|
return scores |
|
|
|
|
|
|
|
|
if self.backtick_token is not None: |
|
|
scores[0, self.backtick_token] = float('-inf') |
|
|
|
|
|
|
|
|
|
|
|
if self.audio_code_mask is not None: |
|
|
|
|
|
if self.audio_code_mask.device != scores.device or self.audio_code_mask.dtype != scores.dtype: |
|
|
self.audio_code_mask = self.audio_code_mask.to(device=scores.device, dtype=scores.dtype) |
|
|
scores = scores + self.audio_code_mask |
|
|
|
|
|
|
|
|
if self.caption_token_count >= 512: |
|
|
|
|
|
if self.newline_token is not None: |
|
|
self._apply_whitelist_inplace(scores, [self.newline_token]) |
|
|
return scores |
|
|
|
|
|
|
|
|
return scores |
|
|
|
|
|
elif self.state == FSMState.DURATION_VALUE: |
|
|
|
|
|
if self.user_provided_metadata["duration"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids: |
|
|
|
|
|
value = self.user_provided_metadata["duration"] |
|
|
value_text = f" {value}\n" |
|
|
value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False) |
|
|
if value_tokens: |
|
|
self.user_field_token_queue = value_tokens |
|
|
self.current_user_field = "duration" |
|
|
|
|
|
self._apply_whitelist_inplace(scores, [value_tokens[0]]) |
|
|
return scores |
|
|
|
|
|
|
|
|
if self.target_duration is not None: |
|
|
target_str = str(int(self.target_duration)) |
|
|
current_pos = len(self.accumulated_value) |
|
|
|
|
|
if current_pos < len(target_str): |
|
|
|
|
|
next_digit = int(target_str[current_pos]) |
|
|
if next_digit in self.digit_tokens: |
|
|
self._apply_whitelist_inplace(scores, [self.digit_tokens[next_digit]]) |
|
|
else: |
|
|
|
|
|
if self.newline_token: |
|
|
self._apply_whitelist_inplace(scores, [self.newline_token]) |
|
|
else: |
|
|
|
|
|
|
|
|
allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree) |
|
|
|
|
|
|
|
|
token_prefix = tuple(self.accumulated_token_ids) |
|
|
if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]: |
|
|
allowed = allowed + [self.newline_token] |
|
|
|
|
|
self._apply_whitelist_inplace(scores, allowed) |
|
|
|
|
|
elif self.state == FSMState.GENRES_VALUE: |
|
|
|
|
|
if self.user_provided_metadata["genres"] is not None and not self.user_field_token_queue and not self.accumulated_value: |
|
|
|
|
|
value = self.user_provided_metadata["genres"] |
|
|
value_text = f" {value}\n" |
|
|
value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False) |
|
|
if value_tokens: |
|
|
self.user_field_token_queue = value_tokens |
|
|
self.current_user_field = "genres" |
|
|
|
|
|
self._apply_whitelist_inplace(scores, [value_tokens[0]]) |
|
|
return scores |
|
|
|
|
|
|
|
|
self._try_reload_genres_vocab() |
|
|
|
|
|
|
|
|
allowed = self._get_allowed_genres_tokens() |
|
|
|
|
|
if allowed: |
|
|
|
|
|
self._apply_whitelist_inplace(scores, allowed) |
|
|
elif self.genres_vocab: |
|
|
|
|
|
|
|
|
if self.newline_token: |
|
|
if self.debug: |
|
|
logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline") |
|
|
self._apply_whitelist_inplace(scores, [self.newline_token]) |
|
|
else: |
|
|
|
|
|
if self._should_end_text_field(scores): |
|
|
if self.newline_token: |
|
|
self._apply_whitelist_inplace(scores, [self.newline_token]) |
|
|
self._transition_to_next_state() |
|
|
else: |
|
|
|
|
|
if not self.accumulated_value.strip(): |
|
|
if self.newline_token: |
|
|
scores[0, self.newline_token] = float('-inf') |
|
|
|
|
|
|
|
|
elif self.state == FSMState.KEYSCALE_VALUE: |
|
|
|
|
|
if self.user_provided_metadata["keyscale"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids: |
|
|
|
|
|
value = self.user_provided_metadata["keyscale"] |
|
|
value_text = f" {value}\n" |
|
|
value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False) |
|
|
if value_tokens: |
|
|
self.user_field_token_queue = value_tokens |
|
|
self.current_user_field = "keyscale" |
|
|
|
|
|
self._apply_whitelist_inplace(scores, [value_tokens[0]]) |
|
|
return scores |
|
|
|
|
|
|
|
|
token_prefix = tuple(self.accumulated_token_ids) |
|
|
if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]: |
|
|
|
|
|
if self.newline_token: |
|
|
self._apply_whitelist_inplace(scores, [self.newline_token]) |
|
|
else: |
|
|
|
|
|
allowed = self._get_allowed_keyscale_tokens() |
|
|
if allowed: |
|
|
self._apply_whitelist_inplace(scores, allowed) |
|
|
else: |
|
|
|
|
|
|
|
|
if self.newline_token: |
|
|
self._apply_whitelist_inplace(scores, [self.newline_token]) |
|
|
|
|
|
elif self.state == FSMState.LANGUAGE_VALUE: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.user_provided_metadata["language"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids: |
|
|
|
|
|
value = self.user_provided_metadata["language"] |
|
|
value_text = f" {value}\n" |
|
|
value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False) |
|
|
if value_tokens: |
|
|
self.user_field_token_queue = value_tokens |
|
|
self.current_user_field = "language" |
|
|
|
|
|
self._apply_whitelist_inplace(scores, [value_tokens[0]]) |
|
|
return scores |
|
|
|
|
|
|
|
|
|
|
|
if not self.accumulated_token_ids: |
|
|
|
|
|
empty_prefix = tuple() |
|
|
if empty_prefix in self.language_prefix_tree: |
|
|
candidate_tokens = list(self.language_prefix_tree[empty_prefix]) |
|
|
|
|
|
if candidate_tokens: |
|
|
|
|
|
|
|
|
candidate_indices = torch.tensor(candidate_tokens, device=scores.device, dtype=torch.long) |
|
|
candidate_scores = scores[0, candidate_indices] |
|
|
|
|
|
|
|
|
best_idx = torch.argmax(candidate_scores).item() |
|
|
top_token_id = candidate_tokens[best_idx] |
|
|
|
|
|
|
|
|
self._apply_whitelist_inplace(scores, [top_token_id]) |
|
|
|
|
|
if self.debug: |
|
|
top_token_text = self.tokenizer.decode([top_token_id]) |
|
|
logger.debug(f"Language field: selected top-1 token {top_token_id} ({repr(top_token_text)}) from {len(candidate_tokens)} candidates") |
|
|
else: |
|
|
|
|
|
if self.newline_token: |
|
|
self._apply_whitelist_inplace(scores, [self.newline_token]) |
|
|
else: |
|
|
|
|
|
if self.newline_token: |
|
|
self._apply_whitelist_inplace(scores, [self.newline_token]) |
|
|
else: |
|
|
|
|
|
|
|
|
token_prefix = tuple(self.accumulated_token_ids) |
|
|
if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]: |
|
|
|
|
|
if self.newline_token: |
|
|
self._apply_whitelist_inplace(scores, [self.newline_token]) |
|
|
else: |
|
|
|
|
|
allowed = self._get_allowed_language_tokens() |
|
|
if allowed: |
|
|
self._apply_whitelist_inplace(scores, allowed) |
|
|
else: |
|
|
|
|
|
if self.newline_token: |
|
|
self._apply_whitelist_inplace(scores, [self.newline_token]) |
|
|
|
|
|
elif self.state == FSMState.TIMESIG_VALUE: |
|
|
|
|
|
if self.user_provided_metadata["timesignature"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids: |
|
|
|
|
|
value = self.user_provided_metadata["timesignature"] |
|
|
value_text = f" {value}\n" |
|
|
value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False) |
|
|
if value_tokens: |
|
|
self.user_field_token_queue = value_tokens |
|
|
self.current_user_field = "timesignature" |
|
|
|
|
|
self._apply_whitelist_inplace(scores, [value_tokens[0]]) |
|
|
return scores |
|
|
|
|
|
|
|
|
token_prefix = tuple(self.accumulated_token_ids) |
|
|
if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]: |
|
|
|
|
|
if self.newline_token: |
|
|
self._apply_whitelist_inplace(scores, [self.newline_token]) |
|
|
else: |
|
|
|
|
|
allowed = self._get_allowed_timesig_tokens() |
|
|
self._apply_whitelist_inplace(scores, allowed) |
|
|
|
|
|
return scores |
|
|
|
|
|
def _transition_to_next_state(self): |
|
|
"""Transition to the next FSM state.""" |
|
|
if self.state in self.next_state: |
|
|
old_state = self.state |
|
|
next_state = self.next_state[self.state] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if old_state == FSMState.THINK_END_TAG: |
|
|
if self.generation_phase == "understand": |
|
|
|
|
|
|
|
|
next_state = FSMState.COMPLETED |
|
|
if self.debug: |
|
|
logger.debug(f"generation_phase='understand': allowing free-form lyrics after </think>") |
|
|
|
|
|
|
|
|
self.state = next_state |
|
|
self.position_in_state = 0 |
|
|
self.accumulated_value = "" |
|
|
self.accumulated_token_ids = [] |
|
|
self.caption_after_newline = False |
|
|
self.caption_token_count = 0 |
|
|
self.caption_ending = False |
|
|
self.pending_field_name = "" |
|
|
if self.debug: |
|
|
logger.debug(f"FSM transition: {old_state.name} -> {self.state.name}") |
|
|
|
|
|
    def update_state(self, generated_token_id: int) -> None:
        """
        Update internal state after a token has been generated.
        This should be called after each token generation.

        Mutates the FSM bookkeeping (current state, accumulated value text,
        accumulated token ids, caption counters, user-field injection queue)
        based on the single token that was just emitted.

        Args:
            generated_token_id: The token ID that was just generated
        """
        # Constrained decoding disabled: nothing to track.
        if not self.enabled:
            return

        # Terminal state: no further transitions.
        if self.state == FSMState.COMPLETED:
            return

        # Audio-code phase only counts emitted codes; no text decoding needed.
        if self.state == FSMState.CODES_GENERATION:
            self.codes_count += 1
            if self.debug and self.target_codes is not None:
                logger.debug(f"Codes count: {self.codes_count}/{self.target_codes}")
            return

        # A user-provided field value is currently being force-injected
        # token-by-token; consume the queue head and transition when empty.
        if self.user_field_token_queue:
            expected_token = self.user_field_token_queue[0]
            if generated_token_id != expected_token:
                # Should not happen when the whitelist forced this token;
                # logged as a warning but the queue is advanced regardless.
                if self.debug:
                    logger.warning(f"Expected token {expected_token} but got {generated_token_id} for user-provided field {self.current_user_field}")
            self.user_field_token_queue.pop(0)

            if not self.user_field_token_queue:
                # Injection finished: jump to the field that follows the one
                # just injected, or fall back to the generic transition.
                if self.debug:
                    logger.debug(f"Completed injection of user-provided field: {self.current_user_field}")
                field_name = self.current_user_field
                self.current_user_field = None

                next_state = self._get_next_field_state(field_name)
                if next_state:
                    old_state = self.state
                    self.state = next_state
                    self.position_in_state = 0
                    self.accumulated_value = ""
                    self.accumulated_token_ids = []
                    if self.debug:
                        logger.debug(f"FSM transition (after user field injection): {old_state.name} -> {self.state.name}")
                else:
                    self._transition_to_next_state()
            return

        # Decode the emitted token once; all branches below inspect its text.
        token_str = self.tokenizer.decode([generated_token_id])

        if self.debug:
            logger.debug(f"Generated token: {repr(token_str)} (id={generated_token_id}), state={self.state.name}")

        # Fixed-string states (tags, field names): advance by the decoded
        # length and transition once the whole fixed string is covered.
        if self.state in self.fixed_strings:
            fixed_str = self.fixed_strings[self.state]
            self.position_in_state += len(token_str)

            if self.position_in_state >= len(fixed_str):
                # stop_at_reasoning: generation ends right after </think>.
                if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                    old_state = self.state
                    self.state = FSMState.COMPLETED
                    self.position_in_state = 0
                    self.accumulated_value = ""
                    self.accumulated_token_ids = []
                    if self.debug:
                        logger.debug(f"FSM transition (stop_at_reasoning): {old_state.name} -> {self.state.name}")
                else:
                    self._transition_to_next_state()

        # Numeric value states: newline terminates the value; otherwise
        # accumulate token ids and any digit characters.
        elif self.state in [FSMState.BPM_VALUE, FSMState.DURATION_VALUE, FSMState.TIMESIG_VALUE]:
            if generated_token_id == self.newline_token:
                # NOTE(review): old_state is assigned but unused here.
                old_state = self.state
                self._transition_to_next_state()
                if self.state in self.fixed_strings:
                    return
            else:
                self.accumulated_token_ids.append(generated_token_id)
                # Only digit text contributes to the numeric accumulated value.
                if token_str.strip().isdigit():
                    self.accumulated_value += token_str.strip()

        # Genres: free text until newline; accumulate raw decoded text.
        elif self.state == FSMState.GENRES_VALUE:
            if generated_token_id == self.newline_token:
                self._transition_to_next_state()
                if self.state in self.fixed_strings:
                    return
            else:
                self.accumulated_value += token_str

        # Caption: free text with length tracking; may morph into a field
        # value state when a "name:" pattern is detected after the caption.
        elif self.state == FSMState.CAPTION_VALUE:
            self.caption_token_count += 1
            self.accumulated_value += token_str

            # Track whether the most recent token ended a line; used by the
            # scoring side to detect caption boundaries.
            if '\n' in token_str:
                self.caption_after_newline = True
            else:
                self.caption_after_newline = False

            # While the caption is ending, tokens are buffered as a pending
            # field name until a ':' arrives.
            if self.caption_ending:
                self.pending_field_name += token_str

                if ':' in token_str or token_str.strip() == ':':
                    field_name_full = self.pending_field_name.strip()
                    # Normalize: drop trailing colon(s) and lowercase.
                    field_name = field_name_full.rstrip(':').strip().lower()

                    if self.debug:
                        logger.debug(f"Detected field name after caption: {repr(field_name)}")

                    # Map a recognized field name to its value state.
                    field_name_to_value_state = {
                        "duration": FSMState.DURATION_VALUE,
                        "genres": FSMState.GENRES_VALUE,
                        "keyscale": FSMState.KEYSCALE_VALUE,
                        "language": FSMState.LANGUAGE_VALUE,
                        "timesignature": FSMState.TIMESIG_VALUE,
                    }

                    if field_name in field_name_to_value_state:
                        old_state = self.state
                        self.state = field_name_to_value_state[field_name]
                        self.position_in_state = 0
                        self.accumulated_value = ""
                        self.accumulated_token_ids = []
                        self.caption_ending = False
                        self.pending_field_name = ""

                        if self.debug:
                            logger.debug(f"FSM transition (caption ending): {old_state.name} -> {self.state.name}")
                    else:
                        # Unrecognized field name: reset caption bookkeeping
                        # and fall back to the default next state.
                        if self.debug:
                            logger.warning(f"Unknown field name after caption: {repr(field_name)}, forcing transition")
                        self.caption_ending = False
                        self.pending_field_name = ""
                        self._transition_to_next_state()

        # Keyscale: newline terminates; otherwise accumulate ids and text.
        elif self.state == FSMState.KEYSCALE_VALUE:
            if generated_token_id == self.newline_token:
                self._transition_to_next_state()
                if self.state in self.fixed_strings:
                    return
            else:
                self.accumulated_token_ids.append(generated_token_id)
                self.accumulated_value += token_str

        # Language: newline terminates; otherwise accumulate ids and text.
        elif self.state == FSMState.LANGUAGE_VALUE:
            if generated_token_id == self.newline_token:
                self._transition_to_next_state()
                if self.state in self.fixed_strings:
                    return
            else:
                self.accumulated_token_ids.append(generated_token_id)
                self.accumulated_value += token_str
|
|
|
|
|
|