ChuxiJ committed on
Commit
f41792a
·
1 Parent(s): 9e64ac5

support user meta control for lm

acestep/constrained_logits_processor.py ADDED
@@ -0,0 +1,1593 @@
+
+from enum import Enum, auto
+from typing import Optional, Dict, Any, Tuple, List, Callable, Set
+from loguru import logger
+from transformers import AutoTokenizer
+from transformers.generation.logits_process import LogitsProcessor
+import os
+import torch
+
+
+# ==============================================================================
+# FSM States for Constrained Decoding
+# ==============================================================================
+class FSMState(Enum):
+    """Finite State Machine states for metadata generation"""
+    THINK_TAG = auto()              # Generating "<think>"
+    NEWLINE_AFTER_THINK = auto()    # Generating "\n" after <think>
+    BPM_NAME = auto()               # Generating "bpm: "
+    BPM_VALUE = auto()              # Generating numeric value 30-300
+    NEWLINE_AFTER_BPM = auto()      # Generating "\n" after bpm value
+    DURATION_NAME = auto()          # Generating "duration: "
+    DURATION_VALUE = auto()         # Generating numeric value 10-600
+    NEWLINE_AFTER_DURATION = auto()
+    GENRES_NAME = auto()            # Generating "genres: "
+    GENRES_VALUE = auto()           # Generating any non-empty string
+    NEWLINE_AFTER_GENRES = auto()
+    KEYSCALE_NAME = auto()          # Generating "keyscale: "
+    KEYSCALE_VALUE = auto()         # Generating keyscale pattern
+    NEWLINE_AFTER_KEYSCALE = auto()
+    TIMESIG_NAME = auto()           # Generating "timesignature: "
+    TIMESIG_VALUE = auto()          # Generating 2, 3, 4, or 6
+    NEWLINE_AFTER_TIMESIG = auto()
+    THINK_END_TAG = auto()          # Generating "</think>"
+    CODES_GENERATION = auto()       # Generating audio codes (no constraints)
+    COMPLETED = auto()              # Generation completed
+
+
+class MetadataConstrainedLogitsProcessor(LogitsProcessor):
+    """
+    FSM-driven LogitsProcessor that constrains generation to produce valid metadata.
+
+    This processor enforces the following format:
+    <think>
+    bpm: [30-300]
+    duration: [10-600]
+    genres: [any non-empty string]
+    keyscale: [A-G][#/♭]? [major/minor]
+    timesignature: [2/3/4/6]
+    </think>
+
+    It uses token masking (setting invalid token logits to -inf) to enforce constraints.
+    For numeric fields, it uses early-blocking to prevent out-of-range values.
+    For field transitions (e.g., end of numeric value), it compares P(newline) vs P(digit).
+    """
+
+    def __init__(
+        self,
+        tokenizer: AutoTokenizer,
+        enabled: bool = True,
+        debug: bool = False,
+        genres_vocab_path: Optional[str] = None,
+        skip_genres: bool = True,
+    ):
+        """
+        Initialize the constrained logits processor.
+
+        This processor should be initialized once when loading the LLM and reused
+        for all generations. Use update_caption() before each generation to update
+        the caption-based genre filtering.
+
+        Args:
+            tokenizer: The tokenizer to use for encoding/decoding
+            enabled: Whether to enable constrained decoding
+            debug: Whether to print debug information
+            genres_vocab_path: Path to genres vocabulary file (one genre per line)
+                If None, defaults to "acestep/genres_vocab.txt"
+            skip_genres: Whether to skip genres generation in metadata (default True)
+        """
+        self.tokenizer = tokenizer
+        self.enabled = enabled
+        self.debug = debug
+        self.skip_genres = skip_genres
+        self.caption: Optional[str] = None  # Set via update_caption() before each generation
+
+        # User-provided metadata fields (optional)
+        # If provided, these fields will be used directly instead of generating
+        # Format: {"bpm": "120", "duration": "234", "keyscale": "G major", "timesignature": "4", "genres": "Pop Rock"}
+        self.user_provided_metadata: Dict[str, Optional[str]] = {
+            "bpm": None,
+            "duration": None,
+            "keyscale": None,
+            "timesignature": None,
+            "genres": None,
+        }
+
+        # Temperature settings for different generation phases (set per-generation)
+        # If set, the processor will apply temperature scaling (divide logits by temperature)
+        # Note: Set base sampler temperature to 1.0 when using processor-based temperature
+        self.metadata_temperature: Optional[float] = None
+        self.codes_temperature: Optional[float] = None
+
+        # Duration constraint for codes generation
+        # 5 codes = 1 second, so target_codes = target_duration * 5
+        self.target_duration: Optional[float] = None  # User-specified duration in seconds
+        self.target_codes: Optional[int] = None       # Computed target codes count
+        self.codes_count: int = 0                     # Counter for generated codes
+
+        # Current state
+        self.state = FSMState.THINK_TAG
+        self.position_in_state = 0   # Position within current state's fixed string
+        self.accumulated_value = ""  # For numeric/text value accumulation (legacy, for compatibility)
+        self.accumulated_token_ids: List[int] = []  # Token ID sequence for keyscale (and other fields)
+
+        # Token queue for user-provided fields (injected directly without generation)
+        self.user_field_token_queue: List[int] = []
+        self.current_user_field: Optional[str] = None  # Current field being injected
+
+        # Pre-compute token IDs for efficiency
+        self._precompute_tokens()
+
+        # Genres vocabulary for constrained decoding
+        self.genres_vocab_path = genres_vocab_path or os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), "genres_vocab.txt"
+        )
+        self.genres_vocab: List[str] = []    # Full vocab
+        self.genres_vocab_mtime: float = 0.0
+        self.genres_trie: Dict = {}          # Trie for full vocab (fallback)
+        self.caption_genres_trie: Dict = {}  # Trie for caption-matched genres (priority)
+        self.caption_matched_genres: List[str] = []  # Genres matched from caption
+        self._char_to_tokens: Dict[str, set] = {}    # Precomputed char -> token IDs mapping
+
+        # Precompute token mappings once (O(vocab_size), runs once at init)
+        self._precompute_char_token_mapping()
+
+        # Field definitions (needed before building prefix trees)
+        self.field_specs = {
+            "bpm": {"min": 30, "max": 300},
+            "duration": {"min": 10, "max": 600},
+            "timesignature": {"valid_values": [2, 3, 4, 6]},
+        }
+
+        # Build valid numeric values for BPM, Duration, Timesignature
+        # These will be used to build prefix trees based on actual tokenization
+        self.valid_bpm_values = [str(v) for v in range(self.field_specs["bpm"]["min"], self.field_specs["bpm"]["max"] + 1)]
+        self.valid_duration_values = [str(v) for v in range(self.field_specs["duration"]["min"], self.field_specs["duration"]["max"] + 1)]
+        self.valid_timesig_values = [str(v) for v in self.field_specs["timesignature"]["valid_values"]]
+
+        # Build keyscale prefix tree (requires _char_to_tokens to be initialized)
+        self.keyscale_prefix_tree = self._build_keyscale_prefix_tree()
+
+        # Build numeric prefix trees (BPM, Duration, Timesignature) with context
+        # IMPORTANT: State machine generates "bpm:" (no space), but tokenizer sees "bpm: " (with space)
+        # Use same logic as keyscale: context_prefix_for_matching (no space) and context_prefix_for_tokenization (with space)
+        self.bpm_prefix_tree = self._build_numeric_prefix_tree(
+            self.valid_bpm_values,
+            context_prefix_for_matching="bpm:",
+            context_prefix_for_tokenization="bpm: "
+        )
+        self.duration_prefix_tree = self._build_numeric_prefix_tree(
+            self.valid_duration_values,
+            context_prefix_for_matching="duration:",
+            context_prefix_for_tokenization="duration: "
+        )
+        self.timesig_prefix_tree = self._build_numeric_prefix_tree(
+            self.valid_timesig_values,
+            context_prefix_for_matching="timesignature:",
+            context_prefix_for_tokenization="timesignature: "
+        )
+
+        self._load_genres_vocab()
+
+        # Note: Caption-based genre filtering is initialized via update_caption() before each generation
+
+        # Fixed strings for each state
+        # IMPORTANT: Do NOT include trailing space after colon - tokenizer will handle spacing
+        # All matching should be done at token level, not string level
+        # NOTE: NEWLINE_AFTER_* states are removed - field values generate newline directly and transition to next field
+        self.fixed_strings = {
+            FSMState.THINK_TAG: "<think>",
+            FSMState.NEWLINE_AFTER_THINK: "\n",
+            FSMState.BPM_NAME: "bpm:",
+            FSMState.DURATION_NAME: "duration:",
+            FSMState.GENRES_NAME: "genres:",
+            FSMState.KEYSCALE_NAME: "keyscale:",
+            FSMState.TIMESIG_NAME: "timesignature:",
+            FSMState.THINK_END_TAG: "</think>",
+        }
+
+        # State transitions - build dynamically based on skip_genres
+        self._build_state_transitions()
+
+    def _get_next_field_state(self, current_field: str) -> Optional[FSMState]:
+        """
+        Get the next field state. Always returns the next field's NAME state,
+        even if the field is user-provided (we still need to generate the field name).
+
+        Args:
+            current_field: Current field name ("bpm", "duration", "genres", "keyscale", "timesignature")
+
+        Returns:
+            Next FSMState (NAME state of next field), or THINK_END_TAG if no more fields
+        """
+        field_order = ["bpm", "duration", "genres", "keyscale", "timesignature"]
+        field_to_state = {
+            "bpm": FSMState.BPM_NAME,
+            "duration": FSMState.DURATION_NAME,
+            "genres": FSMState.GENRES_NAME,
+            "keyscale": FSMState.KEYSCALE_NAME,
+            "timesignature": FSMState.TIMESIG_NAME,
+        }
+
+        try:
+            current_idx = field_order.index(current_field)
+        except ValueError:
+            return FSMState.THINK_END_TAG
+
+        # Find next field in order
+        for i in range(current_idx + 1, len(field_order)):
+            field = field_order[i]
+
+            # Skip genres if skip_genres is True
+            if field == "genres" and self.skip_genres:
+                continue
+
+            # Return the next field's NAME state (even if user-provided, we still generate field name)
+            return field_to_state[field]
+
+        # No more fields, go to THINK_END_TAG
+        return FSMState.THINK_END_TAG
+
+    def _build_state_transitions(self):
+        """Build state transition map based on skip_genres and user-provided metadata."""
+        self.next_state = {
+            FSMState.THINK_TAG: FSMState.NEWLINE_AFTER_THINK,
+            FSMState.NEWLINE_AFTER_THINK: FSMState.BPM_NAME,  # Always start with BPM
+            FSMState.THINK_END_TAG: FSMState.CODES_GENERATION,
+            FSMState.CODES_GENERATION: FSMState.COMPLETED,
+        }
+
+        # Build transitions for all fields (even if user-provided, we still need to generate field name)
+        # Field order: bpm -> duration -> genres -> keyscale -> timesignature
+
+        # BPM field: NAME -> VALUE -> next field
+        self.next_state[FSMState.BPM_NAME] = FSMState.BPM_VALUE
+        self.next_state[FSMState.BPM_VALUE] = self._get_next_field_state("bpm")
+
+        # Duration field: NAME -> VALUE -> next field
+        self.next_state[FSMState.DURATION_NAME] = FSMState.DURATION_VALUE
+        self.next_state[FSMState.DURATION_VALUE] = self._get_next_field_state("duration")
+
+        # Genres field (only if not skipped): NAME -> VALUE -> next field
+        if not self.skip_genres:
+            self.next_state[FSMState.GENRES_NAME] = FSMState.GENRES_VALUE
+            self.next_state[FSMState.GENRES_VALUE] = self._get_next_field_state("genres")
+
+        # Keyscale field: NAME -> VALUE -> next field
+        self.next_state[FSMState.KEYSCALE_NAME] = FSMState.KEYSCALE_VALUE
+        self.next_state[FSMState.KEYSCALE_VALUE] = self._get_next_field_state("keyscale")
+
+        # Timesignature field: NAME -> VALUE -> THINK_END_TAG
+        self.next_state[FSMState.TIMESIG_NAME] = FSMState.TIMESIG_VALUE
+        self.next_state[FSMState.TIMESIG_VALUE] = FSMState.THINK_END_TAG
+
+    def set_skip_genres(self, skip: bool):
+        """Set whether to skip genres generation and rebuild state transitions."""
+        self.skip_genres = skip
+        self._build_state_transitions()
+
+    def set_user_metadata(self, metadata: Optional[Dict[str, Optional[str]]] = None):
+        """
+        Set user-provided metadata fields. Fields that are provided will be used directly
+        instead of generating. Fields that are None will be generated.
+
+        Args:
+            metadata: Dictionary with optional fields:
+                - "bpm": Optional[str] - e.g., "120"
+                - "duration": Optional[str] - e.g., "234"
+                - "keyscale": Optional[str] - e.g., "G major"
+                - "timesignature": Optional[str] - e.g., "4"
+                - "genres": Optional[str] - e.g., "Pop Rock"
+                If None, clears all user-provided metadata.
+        """
+        if metadata is None:
+            metadata = {}
+
+        # Update user-provided metadata
+        for field in ["bpm", "duration", "keyscale", "timesignature", "genres"]:
+            if field in metadata:
+                self.user_provided_metadata[field] = metadata[field]
+            else:
+                self.user_provided_metadata[field] = None
+
+        # Rebuild state transitions to skip provided fields
+        self._build_state_transitions()
+
+        if self.debug:
+            provided_fields = [k for k, v in self.user_provided_metadata.items() if v is not None]
+            if provided_fields:
+                logger.debug(f"User provided metadata fields: {provided_fields}")
+            else:
+                logger.debug("No user-provided metadata, all fields will be generated")
+
+    def _precompute_tokens(self):
+        """Pre-compute commonly used token IDs for efficiency."""
+        # Digit tokens (0-9)
+        self.digit_tokens = {}
+        for d in range(10):
+            tokens = self.tokenizer.encode(str(d), add_special_tokens=False)
+            if tokens:
+                self.digit_tokens[d] = tokens[-1]  # Take last token (in case of prefix)
+
+        # Newline token
+        newline_tokens = self.tokenizer.encode("\n", add_special_tokens=False)
+        self.newline_token = newline_tokens[-1] if newline_tokens else None
+
+        # Note tokens for keyscale (A-G)
+        self.note_tokens = {}
+        for note in "ABCDEFG":
+            tokens = self.tokenizer.encode(note, add_special_tokens=False)
+            if tokens:
+                self.note_tokens[note] = tokens[-1]
+
+        # Sharp/flat tokens
+        self.sharp_tokens = []
+        for s in ["#", "♯"]:
+            tokens = self.tokenizer.encode(s, add_special_tokens=False)
+            if tokens:
+                self.sharp_tokens.append(tokens[-1])
+
+        self.flat_tokens = []
+        for f in ["b", "♭"]:
+            tokens = self.tokenizer.encode(f, add_special_tokens=False)
+            if tokens:
+                self.flat_tokens.append(tokens[-1])
+
+        # Space token
+        space_tokens = self.tokenizer.encode(" ", add_special_tokens=False)
+        self.space_token = space_tokens[-1] if space_tokens else None
+
+        # Major/minor tokens (we'll encode the full words)
+        self.major_start_tokens = []
+        self.minor_start_tokens = []
+        for prefix in ["m", "M"]:
+            tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
+            if tokens:
+                if prefix.lower() == "m":
+                    self.minor_start_tokens.append(tokens[-1])
+                self.major_start_tokens.append(tokens[-1])  # "major" also starts with m
+
+        # Vocab size
+        self.vocab_size = len(self.tokenizer)
+
+        # Comma token for multi-genre support
+        comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
+        self.comma_token = comma_tokens[-1] if comma_tokens else None
+
+        # EOS token for duration-constrained codes generation
+        self.eos_token_id = self.tokenizer.eos_token_id
+
+        # Build valid keyscales set (prefix tree will be built after _char_to_tokens is initialized)
+        # 7 notes × 5 accidentals (none, #, b, ♯, ♭) × 2 modes = 70 valid combinations
+        notes = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
+        accidentals = ['', '#', 'b', '♯', '♭']  # empty + ASCII sharp/flat + Unicode sharp/flat
+        modes = ['major', 'minor']
+
+        self.valid_keyscales = set()
+        for note in notes:
+            for acc in accidentals:
+                for mode in modes:
+                    self.valid_keyscales.add(f"{note}{acc} {mode}")
+
+        # keyscale_prefix_tree will be built in _precompute_char_token_mapping() after _char_to_tokens is ready
+        # Numeric prefix trees will be built after field_specs is defined
+
+    def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
+        """
+        Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.
+
+        IMPORTANT: Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches.
+
+        CRITICAL FIX: The tokenizer may merge the context's trailing space into the next token.
+        For example:
+        - "keyscale: " tokenizes to [10563, 2246, 25, 220] -> ['keys', 'cale', ':', ' ']
+        - "keyscale: G major" tokenizes to [10563, 2246, 25, 479, 3598] -> ['keys', 'cale', ':', ' G', ' major']
+        The space ' ' (220) is merged into ' G' (479), so we can't use simple slicing.
+
+        Strategy:
+        1. For each keyscale (e.g., "G major"), encode the FULL string "keyscale: G major"
+        2. Tokenize to get: [10563, 2246, 25, 479, 3598] -> ['keys', 'cale', ':', ' G', ' major']
+        3. Find where context prefix ends by matching token sequences (handling space merging)
+        4. Extract keyscale value tokens: [479, 3598] (for "G major")
+        5. Build prefix tree using token ID sequences as keys
+
+        This ensures we get the exact tokenization that occurs during generation.
+        """
+        prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}
+
+        # Context prefix that appears before keyscale value
+        # IMPORTANT: The state machine generates "keyscale:" (no space), but when tokenizing
+        # the full string "keyscale: G major", the tokenizer includes space, so we need to
+        # match the actual tokenization behavior.
+        #
+        # Strategy:
+        # 1. Use "keyscale:" (no space) to match the state machine's output
+        # 2. But when building prefix tree, use "keyscale: " (with space) + keyscale to match actual tokenization
+        context_prefix_for_matching = "keyscale:"       # What state machine generates
+        context_prefix_for_tokenization = "keyscale: "  # What tokenizer sees in full string
+
+        # First, tokenize the context (without space) to know its token sequence for matching
+        context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False)
+
+        if self.debug:
+            context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids]
+            logger.debug(f"Context for matching 'keyscale:' tokenizes to {context_token_ids} -> {context_tokens_str}")
+
+        # For each valid keyscale, encode full string and extract value tokens
+        for keyscale in self.valid_keyscales:
+            # Step 1: Encode full string "keyscale: {keyscale}" (with space, as tokenizer sees it)
+            full_text = context_prefix_for_tokenization + keyscale
+            full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)
+
+            # Step 2: Find where context ends in full_token_ids
+            # We match using context_prefix_for_matching ("keyscale:") token sequence
+            # because that's what the state machine actually generates
+            context_end_idx = None
+
+            # Try exact prefix match using context_prefix_for_matching token sequence
+            if len(full_token_ids) >= len(context_token_ids):
+                if full_token_ids[:len(context_token_ids)] == context_token_ids:
+                    context_end_idx = len(context_token_ids)
+
+            if context_end_idx is None:
+                if self.debug:
+                    logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
+                continue
+
+            # Step 3: Extract keyscale value tokens (everything after context)
+            keyscale_token_ids = full_token_ids[context_end_idx:]
+
+            # Step 4: Verify we extracted some tokens (sanity check)
+            if not keyscale_token_ids:
+                if self.debug:
+                    logger.warning(f"No tokens extracted for keyscale '{keyscale}', skipping")
+                continue
+
+            # Step 5: Verify first token is a note (A-G)
+            # This is critical: the first token of keyscale value must be a note
+            first_token_id = keyscale_token_ids[0]
+            first_token_str = self.tokenizer.decode([first_token_id])
+            # Check if first token starts with a note (A-G, case insensitive, with optional leading space)
+            first_char = first_token_str.lstrip()[0].upper() if first_token_str.lstrip() else ""
+            if first_char not in "ABCDEFG":
+                # This keyscale's first token is not a note - skip it
+                if self.debug:
+                    logger.debug(f"Skipping keyscale '{keyscale}': first token is '{first_token_str}' (id={first_token_id}), not a note")
+                continue
+
+            # Step 6: Build prefix mappings from keyscale value tokens
+            # Use token ID sequences as keys (not strings) to avoid tokenization mismatches
+            for i in range(len(keyscale_token_ids) + 1):
+                # Current token sequence prefix (empty tuple for start)
+                token_prefix = tuple(keyscale_token_ids[:i])
+
+                if token_prefix not in prefix_to_tokens:
+                    prefix_to_tokens[token_prefix] = set()
+
+                if i < len(keyscale_token_ids):
+                    # Add next token as allowed for current prefix
+                    next_token_id = keyscale_token_ids[i]
+                    prefix_to_tokens[token_prefix].add(next_token_id)
+                else:
+                    # Complete keyscale should allow newline
+                    if self.newline_token:
+                        prefix_to_tokens[token_prefix].add(self.newline_token)
+
+        if self.debug:
+            logger.debug(f"Built keyscale prefix tree with {len(prefix_to_tokens)} token sequence prefixes")
+            # Check empty prefix (start of keyscale value)
+            empty_prefix = tuple()
+            if empty_prefix in prefix_to_tokens:
+                first_tokens = prefix_to_tokens[empty_prefix]
+                decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)]
+                logger.debug(f"First tokens allowed (empty prefix): {decoded_first}")
+
+        return prefix_to_tokens
+
+    def _build_numeric_prefix_tree(
+        self,
+        valid_values: List[str],
+        context_prefix_for_matching: str = "",
+        context_prefix_for_tokenization: str = ""
+    ) -> Dict[Tuple[int, ...], Set[int]]:
+        """
+        Build prefix tree for numeric field based on actual tokenization with context.
+
+        IMPORTANT: Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches.
+
+        Args:
+            valid_values: List of valid numeric strings (e.g., ["30", "31", ..., "300"])
+            context_prefix_for_matching: Context string that state machine generates (e.g., "bpm:") - no space
+            context_prefix_for_tokenization: Context string for tokenization (e.g., "bpm: ") - with space
+
+        Returns:
+            Dict mapping token ID sequence prefix -> set of allowed token IDs
+        """
+        prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}
+
+        # Encode context for matching (what state machine generates, no space)
+        context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False) if context_prefix_for_matching else []
+
+        # For each valid value, encode it with context and build prefix mappings
+        for value_str in valid_values:
+            # Encode value WITH context (with space) to match actual tokenization
+            full_text = context_prefix_for_tokenization + value_str
+            token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)
+
+            # Find where context ends in full_token_ids using context_prefix_for_matching token sequence
+            context_end_idx = None
+            if len(token_ids) >= len(context_token_ids):
+                if token_ids[:len(context_token_ids)] == context_token_ids:
+                    context_end_idx = len(context_token_ids)
+
+            if context_end_idx is None:
+                if self.debug:
+                    logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
+                continue
+
+            # Extract only tokens that belong to the value itself (skip context tokens)
+            value_token_ids = token_ids[context_end_idx:]
+
+            # Build prefix mappings using token ID sequences as keys
+            for i in range(len(value_token_ids) + 1):
+                # Current token sequence prefix (empty tuple for start)
+                token_prefix = tuple(value_token_ids[:i])
+
+                if token_prefix not in prefix_to_tokens:
+                    prefix_to_tokens[token_prefix] = set()
+
+                if i < len(value_token_ids):
+                    # Add next token as allowed for current prefix
+                    next_token_id = value_token_ids[i]
+                    prefix_to_tokens[token_prefix].add(next_token_id)
+                else:
+                    # Complete value should allow newline
+                    if self.newline_token:
+                        prefix_to_tokens[token_prefix].add(self.newline_token)
+
+        return prefix_to_tokens
+
+    def diagnose_keyscale_prefix_tree(self):
+        """
+        Diagnose the keyscale prefix tree to help debug generation bias.
+        Call this method to print detailed information about allowed tokens at each prefix.
+        """
+        print("=" * 60)
+        print("KEYSCALE PREFIX TREE DIAGNOSIS")
+        print("=" * 60)
+
+        # Check empty prefix (first token)
+        # NOTE: tree keys are token ID tuples, so the empty prefix is tuple(), not ""
+        if tuple() in self.keyscale_prefix_tree:
+            first_tokens = self.keyscale_prefix_tree[tuple()]
+            print(f"\n[Empty prefix] Allowed first tokens ({len(first_tokens)} total):")
+            for t in sorted(first_tokens):
+                decoded = self.tokenizer.decode([t])
+                print(f"  Token {t}: {repr(decoded)}")
+        else:
+            print("\nWARNING: Empty prefix not in tree!")
+
+        # Check single-token prefixes (first note token already chosen)
+        # NOTE: prefixes are token ID tuples, so decode them for display instead of
+        # looking up string keys like "A" or "A " (which would never match)
+        for token_prefix in sorted(self.keyscale_prefix_tree):
+            if len(token_prefix) != 1:
+                continue
+            tokens = self.keyscale_prefix_tree[token_prefix]
+            decoded_prefix = self.tokenizer.decode(list(token_prefix))
+            print(f"\n[Prefix {repr(decoded_prefix)}] Allowed tokens ({len(tokens)}):")
+            for t in sorted(tokens):
+                decoded = self.tokenizer.decode([t])
+                print(f"  Token {t}: {repr(decoded)}")
+
+        # Show some complete keyscales that should be valid
+        print(f"\n[Valid keyscales] Total: {len(self.valid_keyscales)}")
+        sample = sorted(list(self.valid_keyscales))[:10]
+        for ks in sample:
+            print(f"  {repr(ks)}")
+
+        print("=" * 60)
+
589
+ def _load_genres_vocab(self):
590
+ """
591
+ Load genres vocabulary from file. Supports hot reload by checking file mtime.
592
+ File format: one genre per line, lines starting with # are comments.
593
+ """
594
+ if not os.path.exists(self.genres_vocab_path):
595
+ if self.debug:
596
+ logger.debug(f"Genres vocab file not found: {self.genres_vocab_path}")
597
+ return
598
+
599
+ try:
600
+ mtime = os.path.getmtime(self.genres_vocab_path)
601
+ if mtime <= self.genres_vocab_mtime:
602
+ return # File hasn't changed
603
+
604
+ with open(self.genres_vocab_path, 'r', encoding='utf-8') as f:
605
+ genres = []
606
+ for line in f:
607
+ line = line.strip()
608
+ if line and not line.startswith('#'):
609
+ genres.append(line.lower())
610
+
611
+ self.genres_vocab = genres
612
+ self.genres_vocab_mtime = mtime
613
+ self._build_genres_trie()
614
+
615
+ if self.debug:
616
+ logger.debug(f"Loaded {len(self.genres_vocab)} genres from {self.genres_vocab_path}")
617
+ except Exception as e:
618
+ logger.warning(f"Failed to load genres vocab: {e}")
619
+
620
+ def _build_genres_trie(self):
621
+ """
622
+ Build a trie (prefix tree) from genres vocabulary for efficient prefix matching.
623
+ Each node is a dict with:
624
+ - '_end': True if this node represents a complete genre
625
+ - other keys: next characters in the trie
626
+ """
627
+ self.genres_trie = {}
628
+
629
+ for genre in self.genres_vocab:
630
+ node = self.genres_trie
631
+ for char in genre:
632
+ if char not in node:
633
+ node[char] = {}
634
+ node = node[char]
635
+ node['_end'] = True # Mark end of a complete genre
636
+
637
+ if self.debug:
638
+ logger.debug(f"Built genres trie with {len(self.genres_vocab)} entries")
639
+
640
+    def _extract_caption_genres(self, caption: str):
+        """
+        Extract genres from the user's caption that match entries in the vocabulary.
+        This creates a smaller trie for faster and more relevant genre generation.
+
+        Strategy (optimized - O(words * max_genre_len) instead of O(vocab_size)):
+        1. Extract words/phrases from the caption
+        2. For each word, use the trie to find all vocab entries that START with it
+        3. Build a separate trie from the matched genres
+        """
+        if not caption or not self.genres_vocab:
+            return
+
+        caption_lower = caption.lower()
+        matched_genres = set()
+
+        # Extract words from caption (split by common delimiters)
+        import re
+        words = re.split(r'[,\s\-_/\\|]+', caption_lower)
+        words = [w.strip() for w in words if w.strip() and len(w.strip()) >= 2]
+
+        # For each word, find genres in the trie that start with this word
+        for word in words:
+            node = self._get_genres_trie_node(word)
+            if node is not None:
+                # Collect all complete genres under this node
+                self._collect_complete_genres(node, word, matched_genres)
+
+        # Also check whether any word is itself a complete genre
+        # (quick exact-membership check for common single-word genres)
+        genres_set = set(self.genres_vocab)
+        for word in words:
+            if word in genres_set:
+                matched_genres.add(word)
+
+        if not matched_genres:
+            if self.debug:
+                logger.debug("No genres matched in caption, using full vocab")
+            return
+
+        # Build a trie from the matched genres
+        self.caption_matched_genres = list(matched_genres)
+        self.caption_genres_trie = {}
+
+        for genre in matched_genres:
+            node = self.caption_genres_trie
+            for char in genre:
+                if char not in node:
+                    node[char] = {}
+                node = node[char]
+            node['_end'] = True
+
+        if self.debug:
+            logger.debug(f"Matched {len(matched_genres)} genres from caption: {list(matched_genres)[:5]}...")
+
+    def _collect_complete_genres(self, node: Dict, prefix: str, result: set, max_depth: int = 50):
+        """
+        Recursively collect all complete genres under a trie node.
+        Limited depth to avoid too many matches.
+        """
+        if max_depth <= 0:
+            return
+
+        if node.get('_end', False):
+            result.add(prefix)
+
+        # Limit total collected genres to avoid slowdown
+        if len(result) >= 100:
+            return
+
+        for char, child_node in node.items():
+            if char not in ('_end', '_tokens'):
+                self._collect_complete_genres(child_node, prefix + char, result, max_depth - 1)
+
+    def _precompute_char_token_mapping(self):
+        """
+        Precompute mapping from characters to token IDs and token decoded texts.
+        This allows O(1) lookup instead of calling tokenizer.encode()/decode() at runtime.
+
+        Time complexity: O(vocab_size) - runs once during initialization
+
+        Note: Many subword tokenizers (like Qwen) add space prefixes to tokens.
+        We need to handle both the raw first char and the first non-space char.
+        """
+        self._char_to_tokens: Dict[str, set] = {}
+        self._token_to_text: Dict[int, str] = {}  # Precomputed decoded text for each token
+
+        # For each token in vocabulary, get its decoded text
+        for token_id in range(self.vocab_size):
+            try:
+                text = self.tokenizer.decode([token_id])
+
+                if not text:
+                    continue
+
+                # Store the decoded text (normalized to lowercase)
+                # Keep leading spaces for proper concatenation (e.g., " rock" in "pop rock")
+                # Only rstrip trailing whitespace, unless it's a pure whitespace token
+                text_lower = text.lower()
+                if text_lower.strip():  # Has non-whitespace content
+                    normalized_text = text_lower.rstrip()
+                else:  # Pure whitespace token
+                    normalized_text = " "  # Normalize to single space
+                self._token_to_text[token_id] = normalized_text
+
+                # Map first character (including space) to this token
+                first_char = text[0].lower()
+                if first_char not in self._char_to_tokens:
+                    self._char_to_tokens[first_char] = set()
+                self._char_to_tokens[first_char].add(token_id)
+
+                # Also map first non-space character to this token
+                # This handles tokenizers that add space prefixes (e.g., " pop" -> maps to 'p')
+                stripped_text = text.lstrip()
+                if stripped_text and stripped_text != text:
+                    first_nonspace_char = stripped_text[0].lower()
+                    if first_nonspace_char not in self._char_to_tokens:
+                        self._char_to_tokens[first_nonspace_char] = set()
+                    self._char_to_tokens[first_nonspace_char].add(token_id)
+
+            except Exception:
+                continue
+
+        if self.debug:
+            logger.debug(f"Precomputed char->token mapping for {len(self._char_to_tokens)} unique characters")
+
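The precomputation can be illustrated on a tiny hypothetical vocabulary (no real tokenizer involved; a real run would get each `text` from `tokenizer.decode([token_id])`):

```python
# Toy version of the char -> token-ID precomputation. The four-entry vocab
# below is hypothetical and stands in for a subword tokenizer's decode().
toy_vocab = {0: "pop", 1: " rock", 2: "ja", 3: " "}

char_to_tokens = {}
token_to_text = {}
for token_id, text in toy_vocab.items():
    text_lower = text.lower()
    # Pure-whitespace tokens normalize to a single space
    token_to_text[token_id] = text_lower.rstrip() if text_lower.strip() else " "
    # Index by the raw first char and by the first non-space char,
    # so a space-prefixed subword like " rock" is reachable via 'r'
    for candidate in {text[:1].lower(), text.lstrip()[:1].lower()}:
        if candidate:
            char_to_tokens.setdefault(candidate, set()).add(token_id)
```

The double indexing is what lets the genre constraint match " rock" both at the start of a value and mid-phrase after "pop".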
+    def _try_reload_genres_vocab(self):
+        """Check if genres vocab file has been updated and reload if necessary."""
+        if not os.path.exists(self.genres_vocab_path):
+            return
+
+        try:
+            mtime = os.path.getmtime(self.genres_vocab_path)
+            if mtime > self.genres_vocab_mtime:
+                self._load_genres_vocab()
+        except Exception:
+            pass  # Ignore errors during hot reload check
+
+    def _get_genres_trie_node(self, prefix: str) -> Optional[Dict]:
+        """
+        Get the trie node for a given prefix.
+        Returns None if the prefix is not valid (no genres start with this prefix).
+        """
+        node = self.genres_trie
+        for char in prefix.lower():
+            if char not in node:
+                return None
+            node = node[char]
+        return node
+
+    def _is_complete_genre(self, text: str) -> bool:
+        """Check if the given text is a complete genre in the vocabulary."""
+        node = self._get_genres_trie_node(text.strip())
+        return node is not None and node.get('_end', False)
+
+    def _get_trie_node_from_trie(self, trie: Dict, prefix: str) -> Optional[Dict]:
+        """Get a trie node from a specific trie (helper for caption vs full trie)."""
+        node = trie
+        for char in prefix.lower():
+            if char not in node:
+                return None
+            node = node[char]
+        return node
+
+    def _get_allowed_genres_tokens(self) -> List[int]:
+        """
+        Get allowed tokens for the genres field based on trie matching.
+
+        The entire genres string (including commas) must match a complete entry in the vocab.
+        For example, if the vocab contains "pop, rock, jazz", the generated string must exactly
+        match that entry - commas are not treated as separators for individual genres.
+
+        Strategy:
+        1. If caption-matched genres exist, use that smaller trie first (faster + more relevant)
+        2. If no caption matches or the prefix is not in the caption trie, fall back to the full vocab trie
+        3. Get valid next characters from the current trie node
+        4. For each candidate token, verify the full decoded text forms a valid trie prefix
+        """
+        if not self.genres_vocab:
+            # No vocab loaded: return no constraints (caller falls back to
+            # unconstrained, probability-based ending)
+            return []
+
+        # Use the full accumulated value (don't split by comma - treat as a single entry)
+        accumulated = self.accumulated_value.lower()
+        current_genre_prefix = accumulated.strip()
+
+        # Determine which trie to use: caption-matched (priority) or full vocab (fallback)
+        use_caption_trie = False
+        current_node = None
+
+        # Try the caption-matched trie first if available
+        if self.caption_genres_trie:
+            if current_genre_prefix == "":
+                current_node = self.caption_genres_trie
+                use_caption_trie = True
+            else:
+                current_node = self._get_trie_node_from_trie(self.caption_genres_trie, current_genre_prefix)
+                if current_node is not None:
+                    use_caption_trie = True
+
+        # Fall back to the full vocab trie
+        if current_node is None:
+            if current_genre_prefix == "":
+                current_node = self.genres_trie
+            else:
+                current_node = self._get_genres_trie_node(current_genre_prefix)
+
+        if current_node is None:
+            # Invalid prefix, force newline to end
+            if self.newline_token:
+                return [self.newline_token]
+            return []
+
+        # Get valid next characters from the trie node
+        valid_next_chars = set(k for k in current_node.keys() if k not in ('_end', '_tokens'))
+
+        # If the current value is a complete genre, allow newline to end
+        is_complete = current_node.get('_end', False)
+
+        if not valid_next_chars:
+            # No more characters to match, only allow newline if complete
+            allowed = set()
+            if is_complete and self.newline_token:
+                allowed.add(self.newline_token)
+            return list(allowed)
+
+        # Collect candidate tokens based on first character
+        candidate_tokens = set()
+        for char in valid_next_chars:
+            if char in self._char_to_tokens:
+                candidate_tokens.update(self._char_to_tokens[char])
+
+        # Select the appropriate trie for validation
+        active_trie = self.caption_genres_trie if use_caption_trie else self.genres_trie
+
+        # Validate each candidate token: check if prefix + decoded_token is a valid trie prefix
+        allowed = set()
+        for token_id in candidate_tokens:
+            # Use precomputed decoded text (already normalized)
+            decoded_normalized = self._token_to_text.get(token_id, "")
+
+            if not decoded_normalized or not decoded_normalized.strip():
+                # Token decodes to empty or only whitespace - allow if space/comma is a valid next char
+                if ' ' in valid_next_chars or ',' in valid_next_chars:
+                    allowed.add(token_id)
+                continue
+
+            # Build the new prefix by appending the decoded token text.
+            # Leading spaces/commas were preserved during precomputation, so
+            # space-prefixed tokens (e.g., " rock" in "pop rock") concatenate correctly.
+            new_prefix = current_genre_prefix + decoded_normalized
+
+            # Check if new_prefix is a valid prefix in the active trie
+            new_node = self._get_trie_node_from_trie(active_trie, new_prefix)
+            if new_node is not None:
+                allowed.add(token_id)
+
+        # If the current value is a complete genre, also allow newline
+        if is_complete and self.newline_token:
+            allowed.add(self.newline_token)
+
+        return list(allowed)
+
+    def reset(self):
+        """Reset the processor state for a new generation."""
+        self.state = FSMState.THINK_TAG
+        self.position_in_state = 0
+        self.accumulated_value = ""  # Legacy, kept for compatibility
+        self.accumulated_token_ids = []  # Reset token ID sequence
+        self.codes_count = 0  # Reset codes counter
+        self.user_field_token_queue = []  # Reset user field token queue
+        self.current_user_field = None  # Reset current user field
+
+    def set_target_duration(self, duration: Optional[float]):
+        """
+        Set the target duration for codes generation.
+
+        Args:
+            duration: Target duration in seconds. If None, no duration constraint is applied.
+                5 codes = 1 second, so target_codes = duration * 5.
+        """
+        self.target_duration = duration
+        if duration is not None and duration > 0:
+            self.target_codes = int(duration * 5)
+            if self.debug:
+                logger.debug(f"Set target duration: {duration}s -> {self.target_codes} codes")
+        else:
+            self.target_codes = None
+            if self.debug:
+                logger.debug("Target duration cleared, no duration constraint")
+
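The duration-to-codes conversion is the whole contract here, so it can be stated as a pure function (the name is illustrative):

```python
# The 5-codes-per-second mapping used by set_target_duration, as a pure function.
def codes_for_duration(duration):
    if duration is None or duration <= 0:
        return None  # no duration constraint
    return int(duration * 5)  # truncates fractional codes
```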
+    def update_caption(self, caption: Optional[str]):
+        """
+        Update the caption and rebuild the caption-matched genres trie.
+        Call this before each generation to prioritize genres from the new caption.
+
+        Args:
+            caption: User's input caption. If None or empty, clears caption matching.
+        """
+        # Check for hot reload of genres vocabulary
+        self._try_reload_genres_vocab()
+
+        self.caption = caption
+        self.caption_genres_trie = {}
+        self.caption_matched_genres = []
+
+        if caption:
+            self._extract_caption_genres(caption)
+
+        # Also reset FSM state for new generation
+        self.reset()
+
+    def _get_allowed_tokens_for_fixed_string(self, fixed_str: str) -> List[int]:
+        """
+        Get the token IDs that can continue the fixed string from the current position.
+        Returns a list of allowed token IDs.
+        """
+        remaining = fixed_str[self.position_in_state:]
+        if not remaining:
+            return []
+
+        # Try to find tokens that match the beginning of the remaining string
+        allowed = []
+
+        # Try encoding progressively longer prefixes
+        for end in range(1, len(remaining) + 1):
+            prefix = remaining[:end]
+            tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
+            if tokens:
+                # The first token of each prefix encoding is a valid continuation
+                allowed.append(tokens[0])
+
+        # Also check the single-character encoding
+        first_char = remaining[0]
+        char_tokens = self.tokenizer.encode(first_char, add_special_tokens=False)
+        if char_tokens:
+            allowed.extend(char_tokens)
+
+        return list(set(allowed))
+
+    def _get_allowed_digit_tokens(self, min_val: int, max_val: int) -> List[int]:
+        """
+        Get allowed digit tokens based on the accumulated value and range constraints.
+        Uses early blocking to prevent out-of-range values.
+        """
+        if not self.accumulated_value:
+            # First digit: determine valid starting digits
+            allowed_digits = set()
+            for v in range(min_val, max_val + 1):
+                allowed_digits.add(int(str(v)[0]))
+            return [self.digit_tokens[d] for d in allowed_digits if d in self.digit_tokens]
+
+        allowed = []
+
+        for d in range(10):
+            new_value = int(self.accumulated_value + str(d))
+            # A digit is valid if:
+            # 1. new_value <= max_val (max not already exceeded)
+            # 2. new_value could still reach >= min_val
+            #    (i.e., new_value * 10^k >= min_val for some k >= 0)
+
+            if new_value > max_val:
+                continue  # Already exceeded max
+
+            # If new_value is already >= min_val, it's valid.
+            # If new_value < min_val, we need more digits, but new_value * 10 must not exceed max.
+            if new_value >= min_val:
+                allowed.append(d)
+            elif new_value * 10 <= max_val:
+                # Can add more digits
+                allowed.append(d)
+
+        return [self.digit_tokens[d] for d in allowed if d in self.digit_tokens]
+
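The early-blocking rule, stripped of token-ID bookkeeping, is just a range check on candidate digits (function name illustrative):

```python
# Early-blocking digit filter mirroring _get_allowed_digit_tokens: a digit d
# may extend the accumulated value only if the new number stays <= max_val
# and can still reach min_val by appending more digits.
def allowed_next_digits(accumulated, min_val, max_val):
    if not accumulated:
        # First digit: leading digits of every in-range value
        return sorted({int(str(v)[0]) for v in range(min_val, max_val + 1)})
    allowed = []
    for d in range(10):
        new_value = int(accumulated + str(d))
        if new_value > max_val:
            continue  # already overshot
        if new_value >= min_val or new_value * 10 <= max_val:
            allowed.append(d)  # in range, or can still grow into range
    return allowed
```

For a BPM-like range of 60-200, "2" can only be followed by "0" (anything else would make 21x > 200 unreachable within the range).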
+    def _get_allowed_numeric_tokens(self, prefix_tree: Dict[Tuple[int, ...], Set[int]]) -> List[int]:
+        """
+        Get allowed tokens for numeric field using the precomputed prefix tree.
+
+        IMPORTANT: Uses token ID sequence as key (not string) to avoid tokenization mismatches.
+
+        Args:
+            prefix_tree: Precomputed prefix tree mapping token ID sequence -> set of allowed token IDs
+
+        Returns:
+            List of allowed token IDs for current accumulated_token_ids
+        """
+        token_prefix = tuple(self.accumulated_token_ids)
+
+        if token_prefix in prefix_tree:
+            return list(prefix_tree[token_prefix])
+
+        # No valid continuation found - return empty list
+        # The caller will handle this by forcing newline to end the field
+        return []
+
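The token-ID-keyed prefix tree can be sketched with a hypothetical tree (the IDs below do not come from a real tokenizer):

```python
# Token-ID prefix-tree lookup in the style of _get_allowed_numeric_tokens:
# keys are tuples of already-emitted token IDs, values are sets of allowed
# next IDs. All IDs here are hypothetical.
NEWLINE = 99
prefix_tree = {
    (): {11, 12},           # the field may start with token 11 or 12
    (11,): {20, NEWLINE},   # after 11: continue with 20, or end the field
    (12,): {NEWLINE},
    (11, 20): {NEWLINE},
}

def allowed_next(accumulated_token_ids):
    # Unknown prefixes return [] so the caller can force a newline
    return sorted(prefix_tree.get(tuple(accumulated_token_ids), set()))
```

Keying on token IDs rather than decoded strings sidesteps the mismatch where the same text can be produced by several different token sequences.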
+    def _should_end_numeric_field(self, logits: torch.Tensor, min_val: int, max_val: int) -> bool:
+        """
+        Determine if we should end the current numeric field.
+        Returns True if P(newline) > P(any valid digit) AND the current value is valid.
+        """
+        if not self.accumulated_value:
+            return False
+
+        current = int(self.accumulated_value)
+        if current < min_val or current > max_val:
+            return False  # Can't end yet, value not in range
+
+        # Get probabilities
+        probs = torch.softmax(logits, dim=-1)
+
+        newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0
+
+        # Get max probability among valid digit tokens
+        allowed_digits = self._get_allowed_digit_tokens(min_val, max_val)
+        if not allowed_digits:
+            return True  # No more digits possible, must end
+
+        max_digit_prob = max(probs[0, t].item() for t in allowed_digits)
+
+        if self.debug:
+            logger.debug(f"Numeric field decision: newline_prob={newline_prob:.4f}, max_digit_prob={max_digit_prob:.4f}")
+
+        return newline_prob > max_digit_prob
+
+    def _should_end_text_field(self, logits: torch.Tensor) -> bool:
+        """
+        Determine if we should end a text field (genres).
+        Returns True if P(newline) > P(any other token) AND we have some content.
+        """
+        if not self.accumulated_value.strip():
+            return False  # Need at least some content
+
+        probs = torch.softmax(logits, dim=-1)
+        newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0
+
+        # Get max probability among non-newline tokens
+        masked_probs = probs.clone()
+        if self.newline_token:
+            masked_probs[0, self.newline_token] = 0
+        max_other_prob = masked_probs[0].max().item()
+
+        return newline_prob > max_other_prob
+
+    def _get_allowed_keyscale_tokens(self) -> List[int]:
+        """
+        Get allowed tokens for keyscale field using the precomputed prefix tree.
+        Uses token ID sequence as key (not string) to avoid tokenization mismatches.
+        """
+        # Use token ID sequence as key
+        token_prefix = tuple(self.accumulated_token_ids)
+
+        if token_prefix in self.keyscale_prefix_tree:
+            return list(self.keyscale_prefix_tree[token_prefix])
+
+        # Fallback: if we somehow drifted off (shouldn't happen with constrained decoding),
+        # return empty to force newline logic or stop.
+        return []
+
+    def _is_keyscale_complete(self) -> bool:
+        """
+        Check if keyscale value is complete and valid.
+        Uses token ID sequence to check if current prefix allows newline.
+        """
+        token_prefix = tuple(self.accumulated_token_ids)
+        # If current token sequence prefix is in the tree and allows newline, it's complete
+        if token_prefix in self.keyscale_prefix_tree:
+            return self.newline_token in self.keyscale_prefix_tree[token_prefix]
+        return False
+
+    def _get_allowed_timesig_tokens(self) -> List[int]:
+        """
+        Get allowed tokens for timesignature field using the precomputed prefix tree.
+        Uses token ID sequence as key (not string) to avoid tokenization mismatches.
+        """
+        token_prefix = tuple(self.accumulated_token_ids)
+
+        if token_prefix in self.timesig_prefix_tree:
+            return list(self.timesig_prefix_tree[token_prefix])
+
+        # No valid continuation found - return empty list
+        # The caller will handle this by forcing newline to end the field
+        return []
+
+    def __call__(
+        self,
+        input_ids: torch.LongTensor,
+        scores: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        """
+        Apply constrained decoding by modifying logits.
+
+        Args:
+            input_ids: [batch_size, seq_len] input token IDs
+            scores: [batch_size, vocab_size] logits for next token
+
+        Returns:
+            Modified scores with invalid tokens masked to -inf and temperature scaling applied
+        """
+        if not self.enabled:
+            return self._apply_temperature_scaling(scores)
+
+        if self.state == FSMState.COMPLETED:
+            return self._apply_temperature_scaling(scores)
+
+        if self.state == FSMState.CODES_GENERATION:
+            # Apply duration constraint in the codes generation phase
+            if self.target_codes is not None and self.eos_token_id is not None:
+                if self.codes_count < self.target_codes:
+                    # Block EOS token until the target codes count is reached
+                    scores[:, self.eos_token_id] = float('-inf')
+                    if self.debug:
+                        logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
+                else:
+                    # Force EOS token when the target codes count is reached
+                    mask = torch.full_like(scores, float('-inf'))
+                    mask[:, self.eos_token_id] = 0
+                    scores = scores + mask
+                    if self.debug:
+                        logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
+            return self._apply_temperature_scaling(scores)
+
+        batch_size = scores.shape[0]
+
+        # Process each sequence in the batch
+        for b in range(batch_size):
+            result = self._process_single_sequence(input_ids[b], scores[b:b+1])
+            scores[b] = result[0]  # result is [1, vocab_size], need [vocab_size]
+
+        # Apply temperature scaling after constraint masking
+        return self._apply_temperature_scaling(scores)
+
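The codes-phase duration constraint reduces to "block EOS until the count is met, then force it". A sketch on plain float lists (no torch; names illustrative):

```python
# Duration constraint from the codes phase of __call__, on plain lists:
# EOS stays blocked until codes_count reaches target_codes, then is forced.
NEG_INF = float('-inf')

def apply_codes_constraint(logits, eos_id, codes_count, target_codes):
    out = list(logits)
    if codes_count < target_codes:
        out[eos_id] = NEG_INF            # block early EOS
    else:
        out = [NEG_INF] * len(out)       # mask everything...
        out[eos_id] = logits[eos_id]     # ...except EOS
    return out
```

The tensor version does the same thing with an additive mask (`scores + mask` where the mask is `-inf` everywhere except the EOS column).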
+    def _apply_temperature_scaling(self, scores: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        Apply temperature scaling based on current generation phase.
+
+        Temperature scaling: logits = logits / temperature
+        - Lower temperature (< 1.0) makes distribution sharper (more deterministic)
+        - Higher temperature (> 1.0) makes distribution flatter (more diverse)
+
+        Args:
+            scores: [batch_size, vocab_size] logits
+
+        Returns:
+            Temperature-scaled logits
+        """
+        # Determine which temperature to use based on current state
+        if self.state == FSMState.CODES_GENERATION or self.state == FSMState.COMPLETED:
+            temperature = self.codes_temperature
+        else:
+            temperature = self.metadata_temperature
+
+        # If no temperature is set for this phase, return scores unchanged
+        if temperature is None:
+            return scores
+
+        # Avoid division by zero
+        if temperature <= 0:
+            temperature = 1e-6
+
+        # Apply temperature scaling
+        return scores / temperature
+
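The scaling itself is a one-liner and behaves identically on any sequence of logits (function name illustrative):

```python
# Temperature scaling as in _apply_temperature_scaling, on a plain list:
# dividing logits by T < 1 sharpens the softmax, T > 1 flattens it.
def scale_logits(logits, temperature):
    if temperature is None:
        return list(logits)   # no temperature configured for this phase
    if temperature <= 0:
        temperature = 1e-6    # clamp to avoid division by zero
    return [x / temperature for x in logits]
```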
+    def _get_user_provided_field_tokens(self, field_name: str) -> Optional[List[int]]:
+        """
+        Get the token sequence for a user-provided field (field_name + value + newline).
+        Uses the same tokenization logic as prefix tree building.
+
+        Args:
+            field_name: Field name ("bpm", "duration", "keyscale", "timesignature", "genres")
+
+        Returns:
+            List of token IDs for the complete field, or None if the field is not provided
+        """
+        value = self.user_provided_metadata.get(field_name)
+        if value is None:
+            return None
+
+        # Build full field string with space (matching prefix tree tokenization)
+        field_to_prefix = {
+            "bpm": "bpm: ",
+            "duration": "duration: ",
+            "keyscale": "keyscale: ",
+            "timesignature": "timesignature: ",
+            "genres": "genres: ",
+        }
+        prefix = field_to_prefix[field_name]
+        full_text = f"{prefix}{value}\n"
+
+        # Tokenize the full field
+        tokens = self.tokenizer.encode(full_text, add_special_tokens=False)
+
+        # Extract only the field tokens (skip the prefix tokens that match state machine output)
+        # The state machine generates "field_name:" (no space), so we need to match that
+        prefix_for_matching = field_name + ":"
+        prefix_tokens = self.tokenizer.encode(prefix_for_matching, add_special_tokens=False)
+
+        # Find where the prefix ends in the full token list
+        if len(tokens) >= len(prefix_tokens) and tokens[:len(prefix_tokens)] == prefix_tokens:
+            # Return tokens after the prefix (field value + newline)
+            return tokens[len(prefix_tokens):]
+        else:
+            # Fallback: return all tokens (shouldn't happen if tokenization is consistent)
+            if self.debug:
+                logger.warning(f"Could not match prefix tokens for field {field_name}, using all tokens")
+            return tokens
+
+    def _process_single_sequence(
+        self,
+        input_ids: torch.LongTensor,
+        scores: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        """Process a single sequence and return modified scores."""
+
+        # Check if we have tokens queued for a user-provided field.
+        # If so, inject the next token directly.
+        if self.user_field_token_queue:
+            mask = torch.full_like(scores, float('-inf'))
+            next_token = self.user_field_token_queue[0]
+            mask[0, next_token] = 0
+            scores = scores + mask
+            return scores
+
+        # Create mask (all -inf initially)
+        mask = torch.full_like(scores, float('-inf'))
+
+        if self.state in self.fixed_strings:
+            # Fixed string state: force specific tokens
+            allowed = self._get_allowed_tokens_for_fixed_string(self.fixed_strings[self.state])
+            if allowed:
+                for t in allowed:
+                    mask[0, t] = 0
+                # Apply mask
+                scores = scores + mask
+
+                # Position tracking is updated in update_state() after token selection,
+                # which checks whether the selected token completes the fixed string
+            else:
+                # Position exceeds the string, move to the next state
+                old_state = self.state
+                self._transition_to_next_state()
+                # Avoid infinite recursion: if we're still in a fixed_strings state, just return scores
+                if self.state in self.fixed_strings:
+                    # This shouldn't happen, but if it does, return scores to avoid recursion
+                    if self.debug:
+                        logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion")
+                    return scores
+                return self._process_single_sequence(input_ids, torch.zeros_like(scores))
+
+        elif self.state == FSMState.BPM_VALUE:
+            # Check if the field is user-provided and we haven't started injecting yet
+            if self.user_provided_metadata["bpm"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
+                # Initialize the token queue with the field value tokens (value + newline)
+                value = self.user_provided_metadata["bpm"]
+                # Tokenize " value\n" (space + value + newline) to match actual tokenization
+                value_text = f" {value}\n"
+                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+                if value_tokens:
+                    self.user_field_token_queue = value_tokens
+                    self.current_user_field = "bpm"
+                    # Inject the first token
+                    mask[0, value_tokens[0]] = 0
+                    scores = scores + mask
+                    return scores
+
+            # Allow valid numeric tokens using the prefix tree (supports multi-digit tokens like "120")
+            allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree)
+            for t in allowed:
+                mask[0, t] = 0
+
+            # Also allow newline if the current token sequence prefix allows it
+            token_prefix = tuple(self.accumulated_token_ids)
+            if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]:
+                mask[0, self.newline_token] = 0
+
+            scores = scores + mask
+
+        elif self.state == FSMState.DURATION_VALUE:
+            # Check if the field is user-provided and we haven't started injecting yet
+            if self.user_provided_metadata["duration"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
+                # Initialize the token queue with the field value tokens (value + newline)
+                value = self.user_provided_metadata["duration"]
+                value_text = f" {value}\n"
+                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+                if value_tokens:
+                    self.user_field_token_queue = value_tokens
+                    self.current_user_field = "duration"
+                    # Inject the first token
+                    mask[0, value_tokens[0]] = 0
+                    scores = scores + mask
+                    return scores
+
+            # If target_duration is set, force generate that exact value
+            if self.target_duration is not None:
+                target_str = str(int(self.target_duration))
+                current_pos = len(self.accumulated_value)
+
+                if current_pos < len(target_str):
+                    # Force the next digit
+                    next_digit = int(target_str[current_pos])
+                    if next_digit in self.digit_tokens:
+                        mask[0, self.digit_tokens[next_digit]] = 0
+                else:
+                    # All digits generated, force newline
+                    if self.newline_token:
+                        mask[0, self.newline_token] = 0
+                    self._transition_to_next_state()
+
+                scores = scores + mask
+            else:
+                # Normal duration generation with range constraint
+                # Allow valid numeric tokens using the prefix tree (supports multi-digit tokens like "60", "120")
+                allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree)
+                for t in allowed:
+                    mask[0, t] = 0
+
+                # Also allow newline if the current token sequence prefix allows it
+                token_prefix = tuple(self.accumulated_token_ids)
+                if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]:
+                    mask[0, self.newline_token] = 0
+
+                scores = scores + mask
+
+        elif self.state == FSMState.GENRES_VALUE:
+            # Check if the field is user-provided and we haven't started injecting yet
+            if self.user_provided_metadata["genres"] is not None and not self.user_field_token_queue and not self.accumulated_value:
+                # Initialize the token queue with the field value tokens (value + newline)
+                value = self.user_provided_metadata["genres"]
+                value_text = f" {value}\n"
+                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
+                if value_tokens:
+                    self.user_field_token_queue = value_tokens
+                    self.current_user_field = "genres"
+                    # Inject the first token
+                    mask[0, value_tokens[0]] = 0
+                    scores = scores + mask
+                    return scores
+
+            # Try to hot-reload the genres vocab if the file has changed
+            self._try_reload_genres_vocab()
+
+            # Get allowed tokens based on the genres vocabulary
+            allowed = self._get_allowed_genres_tokens()
+
+            if allowed:
+                # Use vocabulary-constrained decoding
+                for t in allowed:
+                    mask[0, t] = 0
+                scores = scores + mask
+            elif self.genres_vocab:
+                # Vocab is loaded but no valid continuation found:
+                # force newline to end the field
+                if self.newline_token:
+                    mask[0, self.newline_token] = 0
+                if self.debug:
+                    logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
+                scores = scores + mask
+            else:
+                # Fallback: no vocab loaded, use probability-based ending
+                if self._should_end_text_field(scores):
+                    if self.newline_token:
+                        mask[0, self.newline_token] = 0
+                    self._transition_to_next_state()
+                    scores = scores + mask
+                else:
+                    # Allow any token except newline if we don't have content yet
+                    if not self.accumulated_value.strip():
+                        if self.newline_token:
+                            scores[0, self.newline_token] = float('-inf')
+                    # Otherwise, don't constrain (fallback behavior)
+
+ elif self.state == FSMState.KEYSCALE_VALUE:
1420
+ # Check if field is user-provided and we haven't started injecting yet
1421
+ if self.user_provided_metadata["keyscale"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
1422
+ # Initialize token queue with field value tokens (value + newline)
1423
+ value = self.user_provided_metadata["keyscale"]
1424
+ value_text = f" {value}\n"
1425
+ value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
1426
+ if value_tokens:
1427
+ self.user_field_token_queue = value_tokens
1428
+ self.current_user_field = "keyscale"
1429
+ # Inject first token
1430
+ mask[0, value_tokens[0]] = 0
1431
+ scores = scores + mask
1432
+ return scores
1433
+
1434
+ # Check if current token sequence is complete (allows newline)
1435
+ token_prefix = tuple(self.accumulated_token_ids)
1436
+ if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]:
1437
+ # Complete keyscale, allow newline
1438
+ if self.newline_token:
1439
+ mask[0, self.newline_token] = 0
1440
+ scores = scores + mask
1441
+ else:
1442
+ # Not complete, allow valid continuation tokens
1443
+ allowed = self._get_allowed_keyscale_tokens()
1444
+ if allowed:
1445
+ for t in allowed:
1446
+ mask[0, t] = 0
1447
+ scores = scores + mask
1448
+ else:
1449
+ # No valid tokens found - force newline to end field
1450
+ # This handles edge cases where keyscale format is unexpected
1451
+ if self.newline_token:
1452
+ mask[0, self.newline_token] = 0
1453
+ scores = scores + mask
1454
+
1455
+ elif self.state == FSMState.TIMESIG_VALUE:
1456
+ # Check if field is user-provided and we haven't started injecting yet
1457
+ if self.user_provided_metadata["timesignature"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
1458
+ # Initialize token queue with field value tokens (value + newline)
1459
+ value = self.user_provided_metadata["timesignature"]
1460
+ value_text = f" {value}\n"
1461
+ value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
1462
+ if value_tokens:
1463
+ self.user_field_token_queue = value_tokens
1464
+ self.current_user_field = "timesignature"
1465
+ # Inject first token
1466
+ mask[0, value_tokens[0]] = 0
1467
+ scores = scores + mask
1468
+ return scores
1469
+
1470
+ # Check if current token sequence is complete (allows newline)
1471
+ token_prefix = tuple(self.accumulated_token_ids)
1472
+ if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]:
1473
+ # Complete value, allow newline
1474
+ if self.newline_token:
1475
+ mask[0, self.newline_token] = 0
1476
+ scores = scores + mask
1477
+            else:
+                # Not complete, allow valid continuation tokens
+                allowed = self._get_allowed_timesig_tokens()
+                if allowed:
+                    for t in allowed:
+                        mask[0, t] = 0
+                    scores = scores + mask
+                elif self.newline_token:
+                    # No valid tokens - force newline to end the field (mirrors keyscale handling)
+                    mask[0, self.newline_token] = 0
+                    scores = scores + mask
1483
+
+        return scores
+
+    def _transition_to_next_state(self):
+        """Transition to the next FSM state."""
+        if self.state in self.next_state:
+            old_state = self.state
+            self.state = self.next_state[self.state]
+            self.position_in_state = 0
+            self.accumulated_value = ""  # Legacy, kept for compatibility
+            self.accumulated_token_ids = []  # Reset token ID sequence for new field
+            if self.debug:
+                logger.debug(f"FSM transition: {old_state.name} -> {self.state.name}")
+
1497
+    def update_state(self, generated_token_id: int):
+        """
+        Update internal state after a token has been generated.
+        This should be called after each token generation.
+
+        Args:
+            generated_token_id: The token ID that was just generated
+        """
+        if not self.enabled:
+            return
+
+        if self.state == FSMState.COMPLETED:
+            return
+
+        if self.state == FSMState.CODES_GENERATION:
+            # Count generated codes for duration constraint
+            self.codes_count += 1
+            if self.debug and self.target_codes is not None:
+                logger.debug(f"Codes count: {self.codes_count}/{self.target_codes}")
+            return
+
+        # Handle user-provided field token injection
+        if self.user_field_token_queue:
+            # Verify the generated token matches the expected token from queue
+            expected_token = self.user_field_token_queue[0]
+            if generated_token_id != expected_token:
+                if self.debug:
+                    logger.warning(f"Expected token {expected_token} but got {generated_token_id} for user-provided field {self.current_user_field}")
+
+            # Remove consumed token from queue
+            self.user_field_token_queue.pop(0)
+
+            # If queue is empty, field injection is complete, transition to next state
+            if not self.user_field_token_queue:
+                if self.debug:
+                    logger.debug(f"Completed injection of user-provided field: {self.current_user_field}")
+                field_name = self.current_user_field
+                self.current_user_field = None
+
+                # Transition to next state (skip VALUE state since we already injected everything)
+                # The next state should be determined by _get_next_field_state
+                next_state = self._get_next_field_state(field_name)
+                if next_state:
+                    old_state = self.state
+                    self.state = next_state
+                    self.position_in_state = 0
+                    self.accumulated_value = ""
+                    self.accumulated_token_ids = []
+                    if self.debug:
+                        logger.debug(f"FSM transition (after user field injection): {old_state.name} -> {self.state.name}")
+                else:
+                    # All fields done, go to THINK_END_TAG
+                    self._transition_to_next_state()
+            return
+
+        token_str = self.tokenizer.decode([generated_token_id])
+
+        if self.debug:
+            logger.debug(f"Generated token: {repr(token_str)} (id={generated_token_id}), state={self.state.name}")
+
+        if self.state in self.fixed_strings:
+            # Update position in fixed string
+            fixed_str = self.fixed_strings[self.state]
+            self.position_in_state += len(token_str)
+
+            # Check if we've completed the fixed string
+            if self.position_in_state >= len(fixed_str):
+                self._transition_to_next_state()
+
+        elif self.state in [FSMState.BPM_VALUE, FSMState.DURATION_VALUE, FSMState.TIMESIG_VALUE]:
+            # Accumulate numeric value using token ID sequence
+            if generated_token_id == self.newline_token:
+                # Newline ends the field
+                self._transition_to_next_state()
+            else:
+                # Add token ID to sequence (for prefix tree lookup)
+                self.accumulated_token_ids.append(generated_token_id)
+                # Also update legacy accumulated_value for compatibility
+                if token_str.strip().isdigit():
+                    self.accumulated_value += token_str.strip()
+
+        elif self.state == FSMState.GENRES_VALUE:
+            if generated_token_id == self.newline_token:
+                self._transition_to_next_state()
+            else:
+                # Genres still uses string-based trie, so keep accumulated_value
+                self.accumulated_value += token_str
+
+        elif self.state == FSMState.KEYSCALE_VALUE:
+            if generated_token_id == self.newline_token:
+                self._transition_to_next_state()
+            else:
+                # Add token ID to sequence (for prefix tree lookup)
+                self.accumulated_token_ids.append(generated_token_id)
+                # Also update legacy accumulated_value for compatibility
+                self.accumulated_value += token_str
+
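The branches above all share one masking pattern: build an all `-inf` mask, zero out the allowed token IDs, and add the mask to the raw scores so disallowed logits can never be sampled. A minimal, framework-free sketch of that pattern (the vocabulary size, logits, and allowed IDs are made up for illustration):

```python
NEG_INF = float("-inf")

# Made-up vocabulary and raw logits; in the processor these come from the model.
vocab_size = 8
scores = [0.3, 1.2, 0.1, 2.5, 0.0, 0.7, 1.9, 0.2]

# Suppose the prefix tree says only token ids 2 and 5 may continue the field.
allowed = {2, 5}

# All -inf mask with zeros at the allowed ids; adding it to the scores makes
# every disallowed logit -inf, so neither greedy nor sampled decoding can pick it.
mask = [0.0 if t in allowed else NEG_INF for t in range(vocab_size)]
constrained = [s + m for s, m in zip(scores, mask)]

picked = max(range(vocab_size), key=lambda t: constrained[t])
assert picked in allowed  # greedy decoding can only choose an allowed id
```

The same idea carries over to the tensor version in the processor, where `mask` is a `(1, vocab_size)` tensor and `mask[0, t] = 0` whitelists token `t`.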
acestep/llm_inference.py CHANGED
@@ -3,11 +3,9 @@
 Handles all LM-related operations including initialization and generation
 """
 import os
-import re
 import traceback
 import time
-from enum import Enum, auto
-from typing import Optional, Dict, Any, Tuple, List, Callable, Set
 from contextlib import contextmanager

 import torch
@@ -20,1086 +18,7 @@ from transformers.generation.logits_process import (
     RepetitionPenaltyLogitsProcessor,
     LogitsProcessor,
 )
-
-
-# ==============================================================================
-# FSM States for Constrained Decoding
-# ==============================================================================
-class FSMState(Enum):
-    """Finite State Machine states for metadata generation"""
-    THINK_TAG = auto()  # Generating "<think>"
-    NEWLINE_AFTER_THINK = auto()  # Generating "\n" after <think>
-    BPM_NAME = auto()  # Generating "bpm: "
-    BPM_VALUE = auto()  # Generating numeric value 30-300
-    NEWLINE_AFTER_BPM = auto()  # Generating "\n" after bpm value
-    DURATION_NAME = auto()  # Generating "duration: "
-    DURATION_VALUE = auto()  # Generating numeric value 10-600
-    NEWLINE_AFTER_DURATION = auto()
-    GENRES_NAME = auto()  # Generating "genres: "
-    GENRES_VALUE = auto()  # Generating any non-empty string
-    NEWLINE_AFTER_GENRES = auto()
-    KEYSCALE_NAME = auto()  # Generating "keyscale: "
-    KEYSCALE_VALUE = auto()  # Generating keyscale pattern
-    NEWLINE_AFTER_KEYSCALE = auto()
-    TIMESIG_NAME = auto()  # Generating "timesignature: "
-    TIMESIG_VALUE = auto()  # Generating 2, 3, 4, or 6
-    NEWLINE_AFTER_TIMESIG = auto()
-    THINK_END_TAG = auto()  # Generating "</think>"
-    CODES_GENERATION = auto()  # Generating audio codes (no constraints)
-    COMPLETED = auto()  # Generation completed
-
-
-class MetadataConstrainedLogitsProcessor(LogitsProcessor):
-    """
-    FSM-driven LogitsProcessor that constrains generation to produce valid metadata.
-
-    This processor enforces the following format:
-    <think>
-    bpm: [30-300]
-    duration: [10-600]
-    genres: [any non-empty string]
-    keyscale: [A-G][#/♭]? [major/minor]
-    timesignature: [2/3/4/6]
-    </think>
-
-    It uses token masking (setting invalid token logits to -inf) to enforce constraints.
-    For numeric fields, it uses early-blocking to prevent out-of-range values.
-    For field transitions (e.g., end of numeric value), it compares P(newline) vs P(digit).
-    """
-
-    def __init__(
-        self,
-        tokenizer: AutoTokenizer,
-        enabled: bool = True,
-        debug: bool = False,
-        genres_vocab_path: Optional[str] = None,
-        skip_genres: bool = True,
-    ):
-        """
-        Initialize the constrained logits processor.
-
-        This processor should be initialized once when loading the LLM and reused
-        for all generations. Use update_caption() before each generation to update
-        the caption-based genre filtering.
-
-        Args:
-            tokenizer: The tokenizer to use for encoding/decoding
-            enabled: Whether to enable constrained decoding
-            debug: Whether to print debug information
-            genres_vocab_path: Path to genres vocabulary file (one genre per line)
-                If None, defaults to "acestep/genres_vocab.txt"
-            skip_genres: Whether to skip genres generation in metadata (default True)
-        """
-        self.tokenizer = tokenizer
-        self.enabled = enabled
-        self.debug = debug
-        self.skip_genres = skip_genres
-        self.caption: Optional[str] = None  # Set via update_caption() before each generation
-
-        # Temperature settings for different generation phases (set per-generation)
-        # If set, the processor will apply temperature scaling (divide logits by temperature)
-        # Note: Set base sampler temperature to 1.0 when using processor-based temperature
-        self.metadata_temperature: Optional[float] = None
-        self.codes_temperature: Optional[float] = None
-
-        # Duration constraint for codes generation
-        # 5 codes = 1 second, so target_codes = target_duration * 5
-        self.target_duration: Optional[float] = None  # User-specified duration in seconds
-        self.target_codes: Optional[int] = None  # Computed target codes count
-        self.codes_count: int = 0  # Counter for generated codes
-
-        # Current state
-        self.state = FSMState.THINK_TAG
-        self.position_in_state = 0  # Position within current state's fixed string
-        self.accumulated_value = ""  # For numeric/text value accumulation
-
-        # Pre-compute token IDs for efficiency
-        self._precompute_tokens()
-
-        # Genres vocabulary for constrained decoding
-        self.genres_vocab_path = genres_vocab_path or os.path.join(
-            os.path.dirname(os.path.abspath(__file__)), "genres_vocab.txt"
-        )
-        self.genres_vocab: List[str] = []  # Full vocab
-        self.genres_vocab_mtime: float = 0.0
-        self.genres_trie: Dict = {}  # Trie for full vocab (fallback)
-        self.caption_genres_trie: Dict = {}  # Trie for caption-matched genres (priority)
-        self.caption_matched_genres: List[str] = []  # Genres matched from caption
-        self._char_to_tokens: Dict[str, set] = {}  # Precomputed char -> token IDs mapping
-
-        # Precompute token mappings once (O(vocab_size), runs once at init)
-        self._precompute_char_token_mapping()
-        self._load_genres_vocab()
-
-        # Note: Caption-based genre filtering is initialized via update_caption() before each generation
-
-        # Field definitions
-        self.field_specs = {
-            "bpm": {"min": 30, "max": 300},
-            "duration": {"min": 10, "max": 600},
-            "timesignature": {"valid_values": [2, 3, 4, 6]},
-        }
-
-        # Fixed strings for each state
-        self.fixed_strings = {
-            FSMState.THINK_TAG: "<think>",
-            FSMState.NEWLINE_AFTER_THINK: "\n",
-            FSMState.BPM_NAME: "bpm: ",
-            FSMState.NEWLINE_AFTER_BPM: "\n",
-            FSMState.DURATION_NAME: "duration: ",
-            FSMState.NEWLINE_AFTER_DURATION: "\n",
-            FSMState.GENRES_NAME: "genres: ",
-            FSMState.NEWLINE_AFTER_GENRES: "\n",
-            FSMState.KEYSCALE_NAME: "keyscale: ",
-            FSMState.NEWLINE_AFTER_KEYSCALE: "\n",
-            FSMState.TIMESIG_NAME: "timesignature: ",
-            FSMState.NEWLINE_AFTER_TIMESIG: "\n",
-            FSMState.THINK_END_TAG: "</think>",
-        }
-
-        # State transitions - build dynamically based on skip_genres
-        self._build_state_transitions()
-
163
-    def _build_state_transitions(self):
-        """Build state transition map based on skip_genres setting."""
-        self.next_state = {
-            FSMState.THINK_TAG: FSMState.NEWLINE_AFTER_THINK,
-            FSMState.NEWLINE_AFTER_THINK: FSMState.BPM_NAME,
-            FSMState.BPM_NAME: FSMState.BPM_VALUE,
-            FSMState.BPM_VALUE: FSMState.NEWLINE_AFTER_BPM,
-            FSMState.NEWLINE_AFTER_BPM: FSMState.DURATION_NAME,
-            FSMState.DURATION_NAME: FSMState.DURATION_VALUE,
-            FSMState.DURATION_VALUE: FSMState.NEWLINE_AFTER_DURATION,
-            FSMState.KEYSCALE_NAME: FSMState.KEYSCALE_VALUE,
-            FSMState.KEYSCALE_VALUE: FSMState.NEWLINE_AFTER_KEYSCALE,
-            FSMState.NEWLINE_AFTER_KEYSCALE: FSMState.TIMESIG_NAME,
-            FSMState.TIMESIG_NAME: FSMState.TIMESIG_VALUE,
-            FSMState.TIMESIG_VALUE: FSMState.NEWLINE_AFTER_TIMESIG,
-            FSMState.NEWLINE_AFTER_TIMESIG: FSMState.THINK_END_TAG,
-            FSMState.THINK_END_TAG: FSMState.CODES_GENERATION,
-            FSMState.CODES_GENERATION: FSMState.COMPLETED,
-        }
-
-        if self.skip_genres:
-            # Skip genres: NEWLINE_AFTER_DURATION -> KEYSCALE_NAME directly
-            self.next_state[FSMState.NEWLINE_AFTER_DURATION] = FSMState.KEYSCALE_NAME
-        else:
-            # Include genres in the flow
-            self.next_state[FSMState.NEWLINE_AFTER_DURATION] = FSMState.GENRES_NAME
-            self.next_state[FSMState.GENRES_NAME] = FSMState.GENRES_VALUE
-            self.next_state[FSMState.GENRES_VALUE] = FSMState.NEWLINE_AFTER_GENRES
-            self.next_state[FSMState.NEWLINE_AFTER_GENRES] = FSMState.KEYSCALE_NAME
-
-    def set_skip_genres(self, skip: bool):
-        """Set whether to skip genres generation and rebuild state transitions."""
-        self.skip_genres = skip
-        self._build_state_transitions()
-
-    def _precompute_tokens(self):
-        """Pre-compute commonly used token IDs for efficiency."""
-        # Digit tokens (0-9)
-        self.digit_tokens = {}
-        for d in range(10):
-            tokens = self.tokenizer.encode(str(d), add_special_tokens=False)
-            if tokens:
-                self.digit_tokens[d] = tokens[-1]  # Take last token (in case of prefix)
-
-        # Newline token
-        newline_tokens = self.tokenizer.encode("\n", add_special_tokens=False)
-        self.newline_token = newline_tokens[-1] if newline_tokens else None
-
-        # Note tokens for keyscale (A-G)
-        self.note_tokens = {}
-        for note in "ABCDEFG":
-            tokens = self.tokenizer.encode(note, add_special_tokens=False)
-            if tokens:
-                self.note_tokens[note] = tokens[-1]
-
-        # Sharp/flat tokens
-        self.sharp_tokens = []
-        for s in ["#", "♯"]:
-            tokens = self.tokenizer.encode(s, add_special_tokens=False)
-            if tokens:
-                self.sharp_tokens.append(tokens[-1])
-
-        self.flat_tokens = []
-        for f in ["b", "♭"]:
-            tokens = self.tokenizer.encode(f, add_special_tokens=False)
-            if tokens:
-                self.flat_tokens.append(tokens[-1])
-
-        # Space token
-        space_tokens = self.tokenizer.encode(" ", add_special_tokens=False)
-        self.space_token = space_tokens[-1] if space_tokens else None
-
-        # Major/minor tokens (we'll encode the full words)
-        self.major_start_tokens = []
-        self.minor_start_tokens = []
-        for prefix in ["m", "M"]:
-            tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
-            if tokens:
-                if prefix.lower() == "m":
-                    self.minor_start_tokens.append(tokens[-1])
-                    self.major_start_tokens.append(tokens[-1])  # "major" also starts with m
-
-        # Vocab size
-        self.vocab_size = len(self.tokenizer)
-
-        # Comma token for multi-genre support
-        comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
-        self.comma_token = comma_tokens[-1] if comma_tokens else None
-
-        # EOS token for duration-constrained codes generation
-        self.eos_token_id = self.tokenizer.eos_token_id
-
-        # Build valid keyscales set and prefix tree for constrained decoding
-        # 7 notes × 5 accidentals (none, #, b, ♯, ♭) × 2 modes = 70 valid combinations
-        notes = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
-        accidentals = ['', '#', 'b', '♯', '♭']  # empty + ASCII sharp/flat + Unicode sharp/flat
-        modes = ['major', 'minor']
-
-        self.valid_keyscales = set()
-        for note in notes:
-            for acc in accidentals:
-                for mode in modes:
-                    self.valid_keyscales.add(f"{note}{acc} {mode}")
-
-        # Build prefix tree for keyscale constrained decoding
-        self.keyscale_prefix_tree = self._build_keyscale_prefix_tree()
-
270
-    def _build_keyscale_prefix_tree(self) -> Dict[str, Set[int]]:
-        """
-        Build keyscale prefix to allowed tokens mapping.
-        For each prefix of each valid keyscale, we store the set of tokens
-        that can continue to form a valid keyscale.
-        """
-        prefix_to_tokens: Dict[str, Set[int]] = {}
-
-        for keyscale in self.valid_keyscales:
-            for i in range(len(keyscale)):
-                prefix = keyscale[:i]
-                next_char = keyscale[i]
-                # Encode the next character
-                tokens = self.tokenizer.encode(next_char, add_special_tokens=False)
-                if prefix not in prefix_to_tokens:
-                    prefix_to_tokens[prefix] = set()
-                prefix_to_tokens[prefix].update(tokens)
-
-        # For complete keyscales, allow newline token
-        for keyscale in self.valid_keyscales:
-            if keyscale not in prefix_to_tokens:
-                prefix_to_tokens[keyscale] = set()
-            if self.newline_token:
-                prefix_to_tokens[keyscale].add(self.newline_token)
-
-        if self.debug:
-            logger.debug(f"Built keyscale prefix tree with {len(prefix_to_tokens)} prefixes for {len(self.valid_keyscales)} valid keyscales")
-
-        return prefix_to_tokens
-
-    def _load_genres_vocab(self):
-        """
-        Load genres vocabulary from file. Supports hot reload by checking file mtime.
-        File format: one genre per line, lines starting with # are comments.
-        """
-        if not os.path.exists(self.genres_vocab_path):
-            if self.debug:
-                logger.debug(f"Genres vocab file not found: {self.genres_vocab_path}")
-            return
-
-        try:
-            mtime = os.path.getmtime(self.genres_vocab_path)
-            if mtime <= self.genres_vocab_mtime:
-                return  # File hasn't changed
-
-            with open(self.genres_vocab_path, 'r', encoding='utf-8') as f:
-                genres = []
-                for line in f:
-                    line = line.strip()
-                    if line and not line.startswith('#'):
-                        genres.append(line.lower())
-
-            self.genres_vocab = genres
-            self.genres_vocab_mtime = mtime
-            self._build_genres_trie()
-
-            if self.debug:
-                logger.debug(f"Loaded {len(self.genres_vocab)} genres from {self.genres_vocab_path}")
-        except Exception as e:
-            logger.warning(f"Failed to load genres vocab: {e}")
-
-    def _build_genres_trie(self):
-        """
-        Build a trie (prefix tree) from genres vocabulary for efficient prefix matching.
-        Each node is a dict with:
-            - '_end': True if this node represents a complete genre
-            - other keys: next characters in the trie
-        """
-        self.genres_trie = {}
-
-        for genre in self.genres_vocab:
-            node = self.genres_trie
-            for char in genre:
-                if char not in node:
-                    node[char] = {}
-                node = node[char]
-            node['_end'] = True  # Mark end of a complete genre
-
-        if self.debug:
-            logger.debug(f"Built genres trie with {len(self.genres_vocab)} entries")
-
-    def _extract_caption_genres(self, caption: str):
-        """
-        Extract genres from the user's caption that match entries in the vocabulary.
-        This creates a smaller trie for faster and more relevant genre generation.
-
-        Strategy (optimized - O(words * max_genre_len) instead of O(vocab_size)):
-        1. Extract words/phrases from caption
-        2. For each word, use trie to find all vocab entries that START with this word
-        3. Build a separate trie from matched genres
-        """
-        if not caption or not self.genres_vocab:
-            return
-
-        caption_lower = caption.lower()
-        matched_genres = set()
-
-        # Extract words from caption (split by common delimiters)
-        import re
-        words = re.split(r'[,\s\-_/\\|]+', caption_lower)
-        words = [w.strip() for w in words if w.strip() and len(w.strip()) >= 2]
-
-        # For each word, find genres in trie that start with this word
-        for word in words:
-            # Find all genres starting with this word using trie traversal
-            node = self._get_genres_trie_node(word)
-            if node is not None:
-                # Collect all complete genres under this node
-                self._collect_complete_genres(node, word, matched_genres)
-
-        # Also check if any word appears as a substring in short genres (< 20 chars)
-        # This is a quick check for common single-word genres
-        genres_set = set(self.genres_vocab)
-        for word in words:
-            if word in genres_set:
-                matched_genres.add(word)
-
-        if not matched_genres:
-            if self.debug:
-                logger.debug(f"No genres matched in caption, using full vocab")
-            return
-
-        # Build a trie from matched genres
-        self.caption_matched_genres = list(matched_genres)
-        self.caption_genres_trie = {}
-
-        for genre in matched_genres:
-            node = self.caption_genres_trie
-            for char in genre:
-                if char not in node:
-                    node[char] = {}
-                node = node[char]
-            node['_end'] = True
-
-        if self.debug:
-            logger.debug(f"Matched {len(matched_genres)} genres from caption: {list(matched_genres)[:5]}...")
-
-    def _collect_complete_genres(self, node: Dict, prefix: str, result: set, max_depth: int = 50):
-        """
-        Recursively collect all complete genres under a trie node.
-        Limited depth to avoid too many matches.
-        """
-        if max_depth <= 0:
-            return
-
-        if node.get('_end', False):
-            result.add(prefix)
-
-        # Limit total collected genres to avoid slowdown
-        if len(result) >= 100:
-            return
-
-        for char, child_node in node.items():
-            if char not in ('_end', '_tokens'):
-                self._collect_complete_genres(child_node, prefix + char, result, max_depth - 1)
-
-    def _precompute_char_token_mapping(self):
-        """
-        Precompute mapping from characters to token IDs and token decoded texts.
-        This allows O(1) lookup instead of calling tokenizer.encode()/decode() at runtime.
-
-        Time complexity: O(vocab_size) - runs once during initialization
-
-        Note: Many subword tokenizers (like Qwen) add space prefixes to tokens.
-        We need to handle both the raw first char and the first non-space char.
-        """
-        self._char_to_tokens: Dict[str, set] = {}
-        self._token_to_text: Dict[int, str] = {}  # Precomputed decoded text for each token
-
-        # For each token in vocabulary, get its decoded text
-        for token_id in range(self.vocab_size):
-            try:
-                text = self.tokenizer.decode([token_id])
-
-                if not text:
-                    continue
-
-                # Store the decoded text (normalized to lowercase)
-                # Keep leading spaces for proper concatenation (e.g., " rock" in "pop rock")
-                # Only rstrip trailing whitespace, unless it's a pure whitespace token
-                text_lower = text.lower()
-                if text_lower.strip():  # Has non-whitespace content
-                    normalized_text = text_lower.rstrip()
-                else:  # Pure whitespace token
-                    normalized_text = " "  # Normalize to single space
-                self._token_to_text[token_id] = normalized_text
-
-                # Map first character (including space) to this token
-                first_char = text[0].lower()
-                if first_char not in self._char_to_tokens:
-                    self._char_to_tokens[first_char] = set()
-                self._char_to_tokens[first_char].add(token_id)
-
-                # Also map first non-space character to this token
-                # This handles tokenizers that add space prefixes (e.g., " pop" -> maps to 'p')
-                stripped_text = text.lstrip()
-                if stripped_text and stripped_text != text:
-                    first_nonspace_char = stripped_text[0].lower()
-                    if first_nonspace_char not in self._char_to_tokens:
-                        self._char_to_tokens[first_nonspace_char] = set()
-                    self._char_to_tokens[first_nonspace_char].add(token_id)
-
-            except Exception:
-                continue
-
-        if self.debug:
-            logger.debug(f"Precomputed char->token mapping for {len(self._char_to_tokens)} unique characters")
-
478
-    def _try_reload_genres_vocab(self):
-        """Check if genres vocab file has been updated and reload if necessary."""
-        if not os.path.exists(self.genres_vocab_path):
-            return
-
-        try:
-            mtime = os.path.getmtime(self.genres_vocab_path)
-            if mtime > self.genres_vocab_mtime:
-                self._load_genres_vocab()
-        except Exception:
-            pass  # Ignore errors during hot reload check
-
-    def _get_genres_trie_node(self, prefix: str) -> Optional[Dict]:
-        """
-        Get the trie node for a given prefix.
-        Returns None if the prefix is not valid (no genres start with this prefix).
-        """
-        node = self.genres_trie
-        for char in prefix.lower():
-            if char not in node:
-                return None
-            node = node[char]
-        return node
-
-    def _is_complete_genre(self, text: str) -> bool:
-        """Check if the given text is a complete genre in the vocabulary."""
-        node = self._get_genres_trie_node(text.strip())
-        return node is not None and node.get('_end', False)
-
-    def _get_trie_node_from_trie(self, trie: Dict, prefix: str) -> Optional[Dict]:
-        """Get a trie node from a specific trie (helper for caption vs full trie)."""
-        node = trie
-        for char in prefix.lower():
-            if char not in node:
-                return None
-            node = node[char]
-        return node
-
-    def _get_allowed_genres_tokens(self) -> List[int]:
-        """
-        Get allowed tokens for genres field based on trie matching.
-
-        The entire genres string (including commas) must match a complete entry in the vocab.
-        For example, if vocab contains "pop, rock, jazz", the generated string must exactly
-        match that entry - we don't treat commas as separators for individual genres.
-
-        Strategy:
-        1. If caption-matched genres exist, use that smaller trie first (faster + more relevant)
-        2. If no caption matches or prefix not in caption trie, fallback to full vocab trie
-        3. Get valid next characters from current trie node
-        4. For each candidate token, verify the full decoded text forms a valid trie prefix
-        """
-        if not self.genres_vocab:
-            # No vocab loaded, allow all except newline if empty
-            return []
-
-        # Use the full accumulated value (don't split by comma - treat as single entry)
-        accumulated = self.accumulated_value.lower()
-        current_genre_prefix = accumulated.strip()
-
-        # Determine which trie to use: caption-matched (priority) or full vocab (fallback)
-        use_caption_trie = False
-        current_node = None
-
-        # Try caption-matched trie first if available
-        if self.caption_genres_trie:
-            if current_genre_prefix == "":
-                current_node = self.caption_genres_trie
-                use_caption_trie = True
-            else:
-                current_node = self._get_trie_node_from_trie(self.caption_genres_trie, current_genre_prefix)
-                if current_node is not None:
-                    use_caption_trie = True
-
-        # Fallback to full vocab trie
-        if current_node is None:
-            if current_genre_prefix == "":
-                current_node = self.genres_trie
-            else:
-                current_node = self._get_genres_trie_node(current_genre_prefix)
-
-        if current_node is None:
-            # Invalid prefix, force newline to end
-            if self.newline_token:
-                return [self.newline_token]
-            return []
-
-        # Get valid next characters from trie node
-        valid_next_chars = set(k for k in current_node.keys() if k not in ('_end', '_tokens'))
-
-        # If current value is a complete genre, allow newline to end
-        is_complete = current_node.get('_end', False)
-
-        if not valid_next_chars:
-            # No more characters to match, only allow newline if complete
-            allowed = set()
-            if is_complete and self.newline_token:
-                allowed.add(self.newline_token)
-            return list(allowed)
-
-        # Collect candidate tokens based on first character
-        candidate_tokens = set()
-        for char in valid_next_chars:
-            if char in self._char_to_tokens:
-                candidate_tokens.update(self._char_to_tokens[char])
-
-        # Select the appropriate trie for validation
-        active_trie = self.caption_genres_trie if use_caption_trie else self.genres_trie
-
-        # Validate each candidate token: check if prefix + decoded_token is a valid trie prefix
-        allowed = set()
-        for token_id in candidate_tokens:
-            # Use precomputed decoded text (already normalized)
-            decoded_normalized = self._token_to_text.get(token_id, "")
-
-            if not decoded_normalized or not decoded_normalized.strip():
-                # Token decodes to empty or only whitespace - allow if space/comma is a valid next char
-                if ' ' in valid_next_chars or ',' in valid_next_chars:
-                    allowed.add(token_id)
-                continue
-
-            # Build new prefix by appending decoded token
-            # Handle space-prefixed tokens (e.g., " rock" from "pop rock")
-            if decoded_normalized.startswith(' ') or decoded_normalized.startswith(','):
-                # Token has leading space/comma - append directly
-                new_prefix = current_genre_prefix + decoded_normalized
-            else:
-                new_prefix = current_genre_prefix + decoded_normalized
-
-            # Check if new_prefix is a valid prefix in the active trie
-            new_node = self._get_trie_node_from_trie(active_trie, new_prefix)
-            if new_node is not None:
-                allowed.add(token_id)
-
-        # If current value is a complete genre, also allow newline
-        if is_complete and self.newline_token:
-            allowed.add(self.newline_token)
-
-        return list(allowed)
-
-    def reset(self):
-        """Reset the processor state for a new generation."""
-        self.state = FSMState.THINK_TAG
-        self.position_in_state = 0
-        self.accumulated_value = ""
-        self.codes_count = 0  # Reset codes counter
-
-    def set_target_duration(self, duration: Optional[float]):
-        """
-        Set the target duration for codes generation.
-
-        Args:
-            duration: Target duration in seconds. If None, no duration constraint is applied.
-                5 codes = 1 second, so target_codes = duration * 5.
-        """
-        self.target_duration = duration
-        if duration is not None and duration > 0:
-            self.target_codes = int(duration * 5)
-            if self.debug:
-                logger.debug(f"Set target duration: {duration}s -> {self.target_codes} codes")
-        else:
-            self.target_codes = None
-            if self.debug:
-                logger.debug("Target duration cleared, no duration constraint")
-
-    def update_caption(self, caption: Optional[str]):
-        """
-        Update the caption and rebuild the caption-matched genres trie.
-        Call this before each generation to prioritize genres from the new caption.
-
-        Args:
-            caption: User's input caption. If None or empty, clears caption matching.
-        """
-        # Check for hot reload of genres vocabulary
-        self._try_reload_genres_vocab()
-
-        self.caption = caption
-        self.caption_genres_trie = {}
-        self.caption_matched_genres = []
-
-        if caption:
-            self._extract_caption_genres(caption)
-
-        # Also reset FSM state for new generation
-        self.reset()
-
-    def _get_allowed_tokens_for_fixed_string(self, fixed_str: str) -> List[int]:
-        """
-        Get the token IDs that can continue the fixed string from current position.
-        Returns list of allowed token IDs.
-        """
-        remaining = fixed_str[self.position_in_state:]
-        if not remaining:
-            return []
-
-        # Try to find tokens that match the beginning of remaining string
-        allowed = []
-
-        # Try encoding progressively longer prefixes
-        for end in range(1, len(remaining) + 1):
-            prefix = remaining[:end]
-            tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
-            if tokens:
-                # The first token that matches is valid
-                allowed.append(tokens[0])
-
-        # Also check single character encoding
-        first_char = remaining[0]
-        char_tokens = self.tokenizer.encode(first_char, add_special_tokens=False)
-        if char_tokens:
-            allowed.extend(char_tokens)
-
-        return list(set(allowed))
-
-    def _get_allowed_digit_tokens(self, min_val: int, max_val: int) -> List[int]:
-        """
-        Get allowed digit tokens based on accumulated value and range constraints.
-        Uses early-blocking to prevent out-of-range values.
-        """
-        if not self.accumulated_value:
-            # First digit: determine valid starting digits
-            allowed_digits = set()
-            for v in range(min_val, max_val + 1):
-                allowed_digits.add(int(str(v)[0]))
-            return [self.digit_tokens[d] for d in allowed_digits if d in self.digit_tokens]
-
-        current = int(self.accumulated_value)
-        allowed = []
-
-        for d in range(10):
-            new_value = int(self.accumulated_value + str(d))
-            # Check if this digit could lead to a valid final value
-            # A digit is valid if:
-            # 1. new_value <= max_val (not already exceeded)
-            # 2. new_value could potentially reach >= min_val
-            #    (i.e., new_value * 10^k >= min_val for some k >= 0)
714
-
715
- if new_value > max_val:
716
- continue # Already exceeded max
717
-
718
- # Check if we can still reach min_val
719
- # If new_value is already >= min_val, it's valid
720
- # If new_value < min_val, we need more digits, but new_value * 10 must not exceed max
721
- if new_value >= min_val:
722
- allowed.append(d)
723
- elif new_value * 10 <= max_val:
724
- # Can add more digits
725
- allowed.append(d)
726
-
727
- return [self.digit_tokens[d] for d in allowed if d in self.digit_tokens]
728
-
729
- def _should_end_numeric_field(self, logits: torch.Tensor, min_val: int, max_val: int) -> bool:
730
- """
731
- Determine if we should end the current numeric field.
732
- Returns True if P(newline) > P(any valid digit) AND current value is valid.
733
- """
734
- if not self.accumulated_value:
735
- return False
736
-
737
- current = int(self.accumulated_value)
738
- if current < min_val or current > max_val:
739
- return False # Can't end yet, value not in range
740
-
741
- # Get probabilities
742
- probs = torch.softmax(logits, dim=-1)
743
-
744
- newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0
745
-
746
- # Get max probability among valid digit tokens
747
- allowed_digits = self._get_allowed_digit_tokens(min_val, max_val)
748
- if not allowed_digits:
749
- return True # No more digits possible, must end
750
-
751
- max_digit_prob = max(probs[0, t].item() for t in allowed_digits)
752
-
753
- if self.debug:
754
- logger.debug(f"Numeric field decision: newline_prob={newline_prob:.4f}, max_digit_prob={max_digit_prob:.4f}")
755
-
756
- return newline_prob > max_digit_prob
757
-
758
- def _should_end_text_field(self, logits: torch.Tensor) -> bool:
759
- """
760
- Determine if we should end a text field (genres).
761
- Returns True if P(newline) > P(any other token) AND we have some content.
762
- """
763
- if not self.accumulated_value.strip():
764
- return False # Need at least some content
765
-
766
- probs = torch.softmax(logits, dim=-1)
767
- newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0
768
-
769
- # Get max probability among non-newline tokens
770
- masked_probs = probs.clone()
771
- if self.newline_token:
772
- masked_probs[0, self.newline_token] = 0
773
- max_other_prob = masked_probs[0].max().item()
774
-
775
- return newline_prob > max_other_prob
776
-
777
- def _get_allowed_keyscale_tokens(self) -> List[int]:
778
- """
779
- Get allowed tokens for keyscale field using prefix tree.
780
- Only allows tokens that can lead to valid keyscales like:
781
- - "A major", "A minor", "A# major", "Ab minor", etc.
782
- """
783
- acc = self.accumulated_value
784
-
785
- if acc in self.keyscale_prefix_tree:
786
- return list(self.keyscale_prefix_tree[acc])
787
-
788
- # No valid continuation found - return empty list
789
- # The caller will handle this by forcing newline to end the field
790
- return []
791
-
792
- def _is_keyscale_complete(self) -> bool:
793
- """Check if keyscale value is complete and valid by checking against valid_keyscales set."""
794
- return self.accumulated_value in self.valid_keyscales
795
-
796
- def _get_allowed_timesig_tokens(self) -> List[int]:
797
- """Get allowed tokens for timesignature field."""
798
- valid_values = self.field_specs["timesignature"]["valid_values"]
799
-
800
- if not self.accumulated_value:
801
- # First digit: must be 2, 3, 4, or 6
802
- return [self.digit_tokens[d] for d in valid_values if d in self.digit_tokens]
803
-
804
- # Already have a digit, should end
805
- return []
806
-
807
- def __call__(
808
- self,
809
- input_ids: torch.LongTensor,
810
- scores: torch.FloatTensor,
811
- ) -> torch.FloatTensor:
812
- """
813
- Apply constrained decoding by modifying logits.
814
-
815
- Args:
816
- input_ids: [batch_size, seq_len] input token IDs
817
- scores: [batch_size, vocab_size] logits for next token
818
-
819
- Returns:
820
- Modified scores with invalid tokens masked to -inf and temperature scaling applied
821
- """
822
- if not self.enabled:
823
- return self._apply_temperature_scaling(scores)
824
-
825
- if self.state == FSMState.COMPLETED:
826
- return self._apply_temperature_scaling(scores)
827
-
828
- if self.state == FSMState.CODES_GENERATION:
829
- # Apply duration constraint in codes generation phase
830
- if self.target_codes is not None and self.eos_token_id is not None:
831
- if self.codes_count < self.target_codes:
832
- # Block EOS token until target codes count is reached
833
- scores[:, self.eos_token_id] = float('-inf')
834
- if self.debug:
835
- logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
836
- else:
837
- # Force EOS token when target codes count is reached
838
- mask = torch.full_like(scores, float('-inf'))
839
- mask[:, self.eos_token_id] = 0
840
- scores = scores + mask
841
- if self.debug:
842
- logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
843
- return self._apply_temperature_scaling(scores)
844
-
845
- batch_size = scores.shape[0]
846
-
847
- # Process each sequence in batch
848
- for b in range(batch_size):
849
- scores[b] = self._process_single_sequence(input_ids[b], scores[b:b+1])
850
-
851
- # Apply temperature scaling after constraint masking
852
- return self._apply_temperature_scaling(scores)
853
-
854
- def _apply_temperature_scaling(self, scores: torch.FloatTensor) -> torch.FloatTensor:
855
- """
856
- Apply temperature scaling based on current generation phase.
857
-
858
- Temperature scaling: logits = logits / temperature
859
- - Lower temperature (< 1.0) makes distribution sharper (more deterministic)
860
- - Higher temperature (> 1.0) makes distribution flatter (more diverse)
861
-
862
- Args:
863
- scores: [batch_size, vocab_size] logits
864
-
865
- Returns:
866
- Temperature-scaled logits
867
- """
868
- # Determine which temperature to use based on current state
869
- if self.state == FSMState.CODES_GENERATION or self.state == FSMState.COMPLETED:
870
- temperature = self.codes_temperature
871
- else:
872
- temperature = self.metadata_temperature
873
-
874
- # If no temperature is set for this phase, return scores unchanged
875
- if temperature is None:
876
- return scores
877
-
878
- # Avoid division by zero
879
- if temperature <= 0:
880
- temperature = 1e-6
881
-
882
- # Apply temperature scaling
883
- return scores / temperature
884
-
885
- def _process_single_sequence(
886
- self,
887
- input_ids: torch.LongTensor,
888
- scores: torch.FloatTensor,
889
- ) -> torch.FloatTensor:
890
- """Process a single sequence and return modified scores."""
891
-
892
- # Create mask (all -inf initially)
893
- mask = torch.full_like(scores, float('-inf'))
894
-
895
- if self.state in self.fixed_strings:
896
- # Fixed string state: force specific tokens
897
- allowed = self._get_allowed_tokens_for_fixed_string(self.fixed_strings[self.state])
898
- if allowed:
899
- for t in allowed:
900
- mask[0, t] = 0
901
- # Apply mask
902
- scores = scores + mask
903
-
904
- # Update position tracking
905
- # We need to check if the selected token completes the fixed string
906
- # This will be done in update_state() after token selection
907
- else:
908
- # Position exceeds string, move to next state
909
- self._transition_to_next_state()
910
- return self._process_single_sequence(input_ids, torch.zeros_like(scores))
911
-
912
- elif self.state == FSMState.BPM_VALUE:
913
- min_val, max_val = self.field_specs["bpm"]["min"], self.field_specs["bpm"]["max"]
914
-
915
- # Check if we should end the field
916
- if self._should_end_numeric_field(scores, min_val, max_val):
917
- # Force newline
918
- if self.newline_token:
919
- mask[0, self.newline_token] = 0
920
- self._transition_to_next_state()
921
- else:
922
- # Allow valid digits
923
- allowed = self._get_allowed_digit_tokens(min_val, max_val)
924
- for t in allowed:
925
- mask[0, t] = 0
926
- # Also allow newline if current value is valid
927
- current = int(self.accumulated_value) if self.accumulated_value else 0
928
- if min_val <= current <= max_val and self.newline_token:
929
- mask[0, self.newline_token] = 0
930
-
931
- scores = scores + mask
932
-
933
- elif self.state == FSMState.DURATION_VALUE:
934
- # If target_duration is set, force generate that exact value
935
- if self.target_duration is not None:
936
- target_str = str(int(self.target_duration))
937
- current_pos = len(self.accumulated_value)
938
-
939
- if current_pos < len(target_str):
940
- # Force the next digit
941
- next_digit = int(target_str[current_pos])
942
- if next_digit in self.digit_tokens:
943
- mask[0, self.digit_tokens[next_digit]] = 0
944
- else:
945
- # All digits generated, force newline
946
- if self.newline_token:
947
- mask[0, self.newline_token] = 0
948
- self._transition_to_next_state()
949
-
950
- scores = scores + mask
951
- else:
952
- # Normal duration generation with range constraint
953
- min_val, max_val = self.field_specs["duration"]["min"], self.field_specs["duration"]["max"]
954
-
955
- if self._should_end_numeric_field(scores, min_val, max_val):
956
- if self.newline_token:
957
- mask[0, self.newline_token] = 0
958
- self._transition_to_next_state()
959
- else:
960
- allowed = self._get_allowed_digit_tokens(min_val, max_val)
961
- for t in allowed:
962
- mask[0, t] = 0
963
- current = int(self.accumulated_value) if self.accumulated_value else 0
964
- if min_val <= current <= max_val and self.newline_token:
965
- mask[0, self.newline_token] = 0
966
-
967
- scores = scores + mask
968
-
969
- elif self.state == FSMState.GENRES_VALUE:
970
- # Try to hot-reload genres vocab if file has changed
971
- self._try_reload_genres_vocab()
972
-
973
- # Get allowed tokens based on genres vocabulary
974
- allowed = self._get_allowed_genres_tokens()
975
-
976
- if allowed:
977
- # Use vocabulary-constrained decoding
978
- for t in allowed:
979
- mask[0, t] = 0
980
- scores = scores + mask
981
- elif self.genres_vocab:
982
- # Vocab is loaded but no valid continuation found
983
- # Force newline to end the field
984
- if self.newline_token:
985
- mask[0, self.newline_token] = 0
986
- if self.debug:
987
- logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
988
- scores = scores + mask
989
- else:
990
- # Fallback: no vocab loaded, use probability-based ending
991
- if self._should_end_text_field(scores):
992
- if self.newline_token:
993
- mask[0, self.newline_token] = 0
994
- self._transition_to_next_state()
995
- scores = scores + mask
996
- else:
997
- # Allow any token except newline if we don't have content yet
998
- if not self.accumulated_value.strip():
999
- if self.newline_token:
1000
- scores[0, self.newline_token] = float('-inf')
1001
- # Otherwise, don't constrain (fallback behavior)
1002
-
1003
- elif self.state == FSMState.KEYSCALE_VALUE:
1004
- if self._is_keyscale_complete():
1005
- # Force newline to end
1006
- if self.newline_token:
1007
- mask[0, self.newline_token] = 0
1008
- self._transition_to_next_state()
1009
- scores = scores + mask
1010
- else:
1011
- allowed = self._get_allowed_keyscale_tokens()
1012
- if allowed:
1013
- for t in allowed:
1014
- mask[0, t] = 0
1015
- scores = scores + mask
1016
- else:
1017
- # No valid tokens found - force newline to end field
1018
- # This handles edge cases where keyscale format is unexpected
1019
- if self.newline_token:
1020
- mask[0, self.newline_token] = 0
1021
- self._transition_to_next_state()
1022
- scores = scores + mask
1023
-
1024
- elif self.state == FSMState.TIMESIG_VALUE:
1025
- if self.accumulated_value:
1026
- # Already have a digit, force newline
1027
- if self.newline_token:
1028
- mask[0, self.newline_token] = 0
1029
- self._transition_to_next_state()
1030
- scores = scores + mask
1031
- else:
1032
- allowed = self._get_allowed_timesig_tokens()
1033
- for t in allowed:
1034
- mask[0, t] = 0
1035
- scores = scores + mask
1036
-
1037
- return scores
1038
-
1039
- def _transition_to_next_state(self):
1040
- """Transition to the next FSM state."""
1041
- if self.state in self.next_state:
1042
- old_state = self.state
1043
- self.state = self.next_state[self.state]
1044
- self.position_in_state = 0
1045
- self.accumulated_value = ""
1046
- if self.debug:
1047
- logger.debug(f"FSM transition: {old_state.name} -> {self.state.name}")
1048
-
1049
- def update_state(self, generated_token_id: int):
1050
- """
1051
- Update internal state after a token has been generated.
1052
- This should be called after each token generation.
1053
-
1054
- Args:
1055
- generated_token_id: The token ID that was just generated
1056
- """
1057
- if not self.enabled:
1058
- return
1059
-
1060
- if self.state == FSMState.COMPLETED:
1061
- return
1062
-
1063
- if self.state == FSMState.CODES_GENERATION:
1064
- # Count generated codes for duration constraint
1065
- self.codes_count += 1
1066
- if self.debug and self.target_codes is not None:
1067
- logger.debug(f"Codes count: {self.codes_count}/{self.target_codes}")
1068
- return
1069
-
1070
- token_str = self.tokenizer.decode([generated_token_id])
1071
-
1072
- if self.debug:
1073
- logger.debug(f"Generated token: {repr(token_str)} (id={generated_token_id}), state={self.state.name}")
1074
-
1075
- if self.state in self.fixed_strings:
1076
- # Update position in fixed string
1077
- fixed_str = self.fixed_strings[self.state]
1078
- self.position_in_state += len(token_str)
1079
-
1080
- # Check if we've completed the fixed string
1081
- if self.position_in_state >= len(fixed_str):
1082
- self._transition_to_next_state()
1083
-
1084
- elif self.state in [FSMState.BPM_VALUE, FSMState.DURATION_VALUE, FSMState.TIMESIG_VALUE]:
1085
- # Accumulate numeric value
1086
- if token_str.strip().isdigit():
1087
- self.accumulated_value += token_str.strip()
1088
- elif generated_token_id == self.newline_token:
1089
- # Newline ends the field
1090
- self._transition_to_next_state()
1091
-
1092
- elif self.state == FSMState.GENRES_VALUE:
1093
- if generated_token_id == self.newline_token:
1094
- self._transition_to_next_state()
1095
- else:
1096
- self.accumulated_value += token_str
1097
-
1098
- elif self.state == FSMState.KEYSCALE_VALUE:
1099
- if generated_token_id == self.newline_token:
1100
- self._transition_to_next_state()
1101
- else:
1102
- self.accumulated_value += token_str
1103
 
1104
 
1105
  class LLMHandler:
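The deleted `_get_allowed_digit_tokens` above enforces numeric ranges (e.g. BPM) by early-blocking: a digit is only allowed if the growing prefix can still land inside `[min_val, max_val]`. A minimal standalone sketch of that rule, with illustrative names rather than the class's token-level API:

```python
def allowed_next_digits(prefix: str, min_val: int, max_val: int) -> list:
    """Digits that may follow `prefix` without ruling out every value in range.

    Mirrors the early-blocking rule: a digit d is allowed if the extended
    number does not exceed max_val and either is already >= min_val or can
    still grow (appending one more digit stays within max_val).
    """
    if not prefix:
        # First digit: any leading digit of a value in [min_val, max_val]
        return sorted({int(str(v)[0]) for v in range(min_val, max_val + 1)})
    allowed = []
    for d in range(10):
        new_value = int(prefix + str(d))
        if new_value > max_val:
            continue  # already exceeded max
        if new_value >= min_val or new_value * 10 <= max_val:
            allowed.append(d)
    return allowed
```

For a BPM-like range of 60-200 this yields first digits {1, 2, 6, 7, 8, 9}, and after the prefix "20" only 0 remains valid.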
 
 Handles all LM-related operations including initialization and generation
 """
 import os
 import traceback
 import time
+from typing import Optional, Dict, Any, Tuple, List
 from contextlib import contextmanager
 
 import torch
 
     RepetitionPenaltyLogitsProcessor,
     LogitsProcessor,
 )
+from .constrained_logits_processor import MetadataConstrainedLogitsProcessor
 
 
 
 class LLMHandler:
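The `metadata_temperature` and `codes_temperature` parameters in the hunks below feed the processor's phase-dependent temperature scaling: metadata fields and codes can be sampled at different temperatures, and `None` leaves logits untouched for that phase. A rough standalone sketch of that selection (plain lists instead of tensors; names are illustrative):

```python
from typing import List, Optional


def scale_logits(logits: List[float], phase: str,
                 metadata_temperature: Optional[float],
                 codes_temperature: Optional[float]) -> List[float]:
    """Divide logits by the temperature of the current phase.

    `None` disables scaling for that phase; non-positive temperatures are
    clamped to a tiny epsilon to avoid division by zero, as in the diff.
    """
    temperature = codes_temperature if phase == "codes" else metadata_temperature
    if temperature is None:
        return logits
    temperature = max(temperature, 1e-6)  # avoid division by zero
    return [x / temperature for x in logits]
```

Dividing by a temperature below 1.0 sharpens the distribution after softmax, which is why the metadata phase defaults to 0.85 while codes can stay unscaled.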
 
@@ -1429,6 +348,7 @@ class LLMHandler:
         metadata_temperature: Optional[float] = 0.85,
         codes_temperature: Optional[float] = None,
         target_duration: Optional[float] = None,
+        user_metadata: Optional[Dict[str, Optional[str]]] = None,
     ) -> str:
         """Shared vllm path: accept prebuilt formatted prompt and return text."""
         from nanovllm import SamplingParams
@@ -1447,6 +367,8 @@ class LLMHandler:
         self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
         self.constrained_processor.update_caption(formatted_prompt)  # Use formatted prompt for genre extraction
         self.constrained_processor.set_target_duration(target_duration)
+        # Always call set_user_metadata to ensure previous settings are cleared if None
+        self.constrained_processor.set_user_metadata(user_metadata)
 
         constrained_processor = self.constrained_processor
 
@@ -1701,6 +623,7 @@ class LLMHandler:
         use_constrained_decoding: bool = True,
         constrained_decoding_debug: bool = False,
         target_duration: Optional[float] = None,
+        user_metadata: Optional[Dict[str, Optional[str]]] = None,
     ) -> str:
         """Shared PyTorch path: accept prebuilt formatted prompt and return text."""
         inputs = self.llm_tokenizer(
@@ -1718,6 +641,8 @@ class LLMHandler:
         self.constrained_processor.debug = constrained_decoding_debug
         self.constrained_processor.update_caption(formatted_prompt)  # Use formatted prompt for genre extraction
         self.constrained_processor.set_target_duration(target_duration)
+        # Always call set_user_metadata to ensure previous settings are cleared if None
+        self.constrained_processor.set_user_metadata(user_metadata)
 
         constrained_processor = self.constrained_processor
 
@@ -2048,6 +973,7 @@ class LLMHandler:
         top_p = cfg.get("top_p")
         repetition_penalty = cfg.get("repetition_penalty", 1.0)
         target_duration = cfg.get("target_duration")
+        user_metadata = cfg.get("user_metadata")  # User-provided metadata fields
 
         try:
             if self.llm_backend == "vllm":
@@ -2062,6 +988,7 @@ class LLMHandler:
                     use_constrained_decoding=use_constrained_decoding,
                     constrained_decoding_debug=constrained_decoding_debug,
                     target_duration=target_duration,
+                    user_metadata=user_metadata,
                 )
                 return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
@@ -2077,6 +1004,7 @@ class LLMHandler:
                    use_constrained_decoding=use_constrained_decoding,
                    constrained_decoding_debug=constrained_decoding_debug,
                    target_duration=target_duration,
+                   user_metadata=user_metadata,
                )
                return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}"
 
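The `user_metadata` dict threaded through both backends lets callers pin individual metadata fields, and the comment in the diff stresses calling `set_user_metadata` unconditionally so a `None` clears state left over from a previous request. A toy sketch of that clear-on-`None` contract (hypothetical class; the real processor constrains decoding rather than merely storing strings):

```python
from typing import Dict, Optional


class UserMetadataStore:
    """Holds user-pinned metadata fields; passing None clears everything."""

    def __init__(self) -> None:
        self.fields: Dict[str, str] = {}

    def set_user_metadata(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> None:
        # Always reset first so settings from a previous request never leak
        self.fields = {}
        if user_metadata:
            # Keep only the fields the user actually provided
            self.fields = {k: v for k, v in user_metadata.items() if v is not None}
```

Calling the setter on every request, even with `None`, is what makes the per-request behavior stateless despite the processor object being reused.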