| """ |
| Generation Input Handlers Module |
| Contains event handlers and helper functions related to generation inputs |
| """ |
| import os |
| import json |
| import random |
| import glob |
| import gradio as gr |
| from typing import Optional, List, Tuple |
| from acestep.constants import ( |
| TASK_TYPES_TURBO, |
| TASK_TYPES_BASE, |
| ) |
| from acestep.gradio_ui.i18n import t |
| from acestep.inference import understand_music, create_sample, format_sample |
|
|
|
|
def parse_and_validate_timesteps(
    timesteps_str: str,
    inference_steps: int
) -> Tuple[Optional[List[float]], bool, str]:
    """
    Parse a comma-separated timesteps string and validate it.

    Args:
        timesteps_str: Comma-separated timesteps string (e.g., "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
        inference_steps: Expected number of inference steps

    Returns:
        Tuple of (parsed_timesteps, has_warning, warning_message)
        - parsed_timesteps: List of float timesteps (always terminated with 0.0),
          or None if invalid/empty
        - has_warning: Whether a warning was shown
        - warning_message: Description of the warning
    """
    if not timesteps_str or not timesteps_str.strip():
        return None, False, ""

    values = [v.strip() for v in timesteps_str.split(",") if v.strip()]

    if not values:
        return None, False, ""

    try:
        timesteps = [float(v) for v in values]
    except ValueError:
        gr.Warning(t("messages.invalid_timesteps_format"))
        return None, True, "Invalid format"

    # Ensure the schedule terminates at zero. Compare numerically: the old
    # string comparison against "0" treated trailing "0.0" / "0.00" as
    # missing and appended a duplicate zero, producing a spurious step.
    if timesteps[-1] != 0.0:
        timesteps.append(0.0)

    # Every timestep must lie in the closed interval [0, 1].
    if any(ts < 0 or ts > 1 for ts in timesteps):
        gr.Warning(t("messages.timesteps_out_of_range"))
        return None, True, "Out of range"

    # N schedule points define N-1 denoising steps.
    actual_steps = len(timesteps) - 1
    if actual_steps != inference_steps:
        gr.Warning(t("messages.timesteps_count_mismatch", actual=actual_steps, expected=inference_steps))
        return timesteps, True, f"Using {actual_steps} steps from timesteps"

    return timesteps, False, ""
|
|
|
|
def load_metadata(file_obj):
    """Load generation parameters from a JSON metadata file.

    Args:
        file_obj: A Gradio file object (exposing the path via ``.name``) or a
            plain filesystem path string.

    Returns:
        A 37-tuple: 36 generation-parameter values followed by a success flag.
        On any failure returns ``[None] * 36 + [False]``.
    """
    if file_obj is None:
        gr.Warning(t("messages.no_file_selected"))
        return [None] * 36 + [False]

    try:
        # Gradio file components carry the path in .name; accept raw paths too.
        if hasattr(file_obj, 'name'):
            filepath = file_obj.name
        else:
            filepath = file_obj

        with open(filepath, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        task_type = metadata.get('task_type', 'text2music')
        captions = metadata.get('caption', '')
        lyrics = metadata.get('lyrics', '')
        vocal_language = metadata.get('vocal_language', 'unknown')

        # BPM may be absent, "N/A", empty, or non-numeric; normalize to None.
        bpm_value = metadata.get('bpm')
        if bpm_value is not None and bpm_value != "N/A":
            try:
                bpm = int(bpm_value) if bpm_value else None
            except (ValueError, TypeError):  # narrowed from a bare except
                bpm = None
        else:
            bpm = None

        key_scale = metadata.get('keyscale', '')
        time_signature = metadata.get('timesignature', '')

        # Duration may be absent or "N/A"; -1 means "auto-detect".
        duration_value = metadata.get('duration', -1)
        if duration_value is not None and duration_value != "N/A":
            try:
                audio_duration = float(duration_value)
            except (ValueError, TypeError):  # narrowed from a bare except
                audio_duration = -1
        else:
            audio_duration = -1

        batch_size = metadata.get('batch_size', 2)
        inference_steps = metadata.get('inference_steps', 8)
        guidance_scale = metadata.get('guidance_scale', 7.0)
        seed = metadata.get('seed', '-1')
        # Loading saved params always pins the stored seed (no re-randomize).
        random_seed = False
        use_adg = metadata.get('use_adg', False)
        cfg_interval_start = metadata.get('cfg_interval_start', 0.0)
        cfg_interval_end = metadata.get('cfg_interval_end', 1.0)
        audio_format = metadata.get('audio_format', 'mp3')
        lm_temperature = metadata.get('lm_temperature', 0.85)
        lm_cfg_scale = metadata.get('lm_cfg_scale', 2.0)
        lm_top_k = metadata.get('lm_top_k', 0)
        lm_top_p = metadata.get('lm_top_p', 0.9)
        lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
        use_cot_metas = metadata.get('use_cot_metas', True)
        use_cot_caption = metadata.get('use_cot_caption', True)
        use_cot_language = metadata.get('use_cot_language', True)
        audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
        think = metadata.get('thinking', True)
        audio_codes = metadata.get('audio_codes', '')
        repainting_start = metadata.get('repainting_start', 0.0)
        repainting_end = metadata.get('repainting_end', -1)
        track_name = metadata.get('track_name')
        complete_track_classes = metadata.get('complete_track_classes', [])
        shift = metadata.get('shift', 3.0)
        infer_method = metadata.get('infer_method', 'ode')
        custom_timesteps = metadata.get('timesteps', '')
        if custom_timesteps is None:
            custom_timesteps = ''
        instrumental = metadata.get('instrumental', False)

        gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))

        return (
            task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
            audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
            use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method,
            custom_timesteps,
            audio_format, lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
            use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
            think, audio_codes, repainting_start, repainting_end,
            track_name, complete_track_classes, instrumental,
            True
        )

    except json.JSONDecodeError as e:
        gr.Warning(t("messages.invalid_json", error=str(e)))
        return [None] * 36 + [False]
    except Exception as e:
        gr.Warning(t("messages.load_error", error=str(e)))
        return [None] * 36 + [False]
|
|
|
|
def load_random_example(task_type: str):
    """Pick a random JSON example for the given task type and return its fields.

    Args:
        task_type: The task type (e.g., "text2music")

    Returns:
        Tuple of (caption, lyrics, think, bpm, duration, keyscale, language,
        timesignature) for updating UI components; defaults on any failure.
    """
    empty = ("", "", True, None, None, "", "", "")
    try:
        module_path = os.path.abspath(__file__)
        # The project root sits four directory levels above this module.
        root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(module_path))))
        examples_dir = os.path.join(root, "examples", task_type)

        if not os.path.exists(examples_dir):
            gr.Warning(f"Examples directory not found: examples/{task_type}/")
            return empty

        candidates = glob.glob(os.path.join(examples_dir, "*.json"))
        if not candidates:
            gr.Warning(f"No JSON files found in examples/{task_type}/")
            return empty

        chosen = random.choice(candidates)

        try:
            with open(chosen, 'r', encoding='utf-8') as fh:
                data = json.load(fh)

            def as_text(raw):
                # Coerce non-string values to str; falsy values become ''.
                if isinstance(raw, str):
                    return raw
                return str(raw) if raw else ''

            caption_value = as_text(data.get('caption', data.get('prompt', '')))
            lyrics_value = as_text(data.get('lyrics', ''))

            think_value = data.get('think', True)
            if not isinstance(think_value, bool):
                think_value = True

            def as_number(key, cast):
                # Numeric field that may be absent, "N/A", empty, or malformed.
                if key in data and data[key] not in [None, "N/A", ""]:
                    try:
                        return cast(data[key])
                    except (ValueError, TypeError):
                        pass
                return None

            bpm_value = as_number('bpm', int)
            duration_value = as_number('duration', float)

            def as_meta(key):
                # String metadata field; None / "N/A" normalize to ''.
                raw = data.get(key, '')
                return '' if raw in [None, "N/A"] else raw

            keyscale_value = as_meta('keyscale')
            language_value = as_meta('language')
            timesignature_value = as_meta('timesignature')

            gr.Info(t("messages.example_loaded", filename=os.path.basename(chosen)))
            return (caption_value, lyrics_value, think_value, bpm_value,
                    duration_value, keyscale_value, language_value, timesignature_value)

        except json.JSONDecodeError as e:
            gr.Warning(t("messages.example_failed", filename=os.path.basename(chosen), error=str(e)))
            return empty
        except Exception as e:
            gr.Warning(t("messages.example_error", error=str(e)))
            return empty

    except Exception as e:
        gr.Warning(t("messages.example_error", error=str(e)))
        return empty
|
|
|
|
def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug: bool = False):
    """Sample an example via the LM when available, else from disk.

    Gradio wrapper around the understand_music API from acestep.inference;
    falls back to load_random_example whenever the LM is missing or fails.

    Args:
        llm_handler: LLM handler instance
        task_type: The task type (e.g., "text2music")
        constrained_decoding_debug: Whether to enable debug logging for constrained decoding

    Returns:
        Tuple of (caption, lyrics, think, bpm, duration, keyscale, language,
        timesignature) for updating UI components
    """
    if not llm_handler.llm_initialized:
        # No LM loaded: serve a random on-disk example instead.
        return load_random_example(task_type)

    try:
        outcome = understand_music(
            llm_handler=llm_handler,
            audio_codes="NO USER INPUT",
            temperature=0.85,
            use_constrained_decoding=True,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        if not outcome.success:
            gr.Warning(t("messages.lm_fallback"))
            return load_random_example(task_type)

        gr.Info(t("messages.lm_generated"))
        return (
            outcome.caption,
            outcome.lyrics,
            True,
            outcome.bpm,
            outcome.duration,
            outcome.keyscale,
            outcome.language,
            outcome.timesignature,
        )
    except Exception:
        gr.Warning(t("messages.lm_fallback"))
        return load_random_example(task_type)
|
|
|
|
def load_random_simple_description():
    """Pick a random simple-mode example and return its description fields.

    Returns:
        Tuple of (description, instrumental, vocal_language) for updating UI
        components; gr.update() no-ops on any failure.
    """
    try:
        module_path = os.path.abspath(__file__)
        # The project root sits four directory levels above this module.
        root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(module_path))))
        examples_dir = os.path.join(root, "examples", "simple_mode")

        if not os.path.exists(examples_dir):
            gr.Warning(t("messages.simple_examples_not_found"))
            return gr.update(), gr.update(), gr.update()

        candidates = glob.glob(os.path.join(examples_dir, "*.json"))
        if not candidates:
            gr.Warning(t("messages.simple_examples_empty"))
            return gr.update(), gr.update(), gr.update()

        chosen = random.choice(candidates)

        try:
            with open(chosen, 'r', encoding='utf-8') as fh:
                payload = json.load(fh)

            description = payload.get('description', '')
            instrumental = payload.get('instrumental', False)
            vocal_language = payload.get('vocal_language', 'unknown')
            # Some examples store a list of languages; take the first one.
            if isinstance(vocal_language, list):
                vocal_language = vocal_language[0] if vocal_language else 'unknown'

            gr.Info(t("messages.simple_example_loaded", filename=os.path.basename(chosen)))
            return description, instrumental, vocal_language

        except json.JSONDecodeError as e:
            gr.Warning(t("messages.example_failed", filename=os.path.basename(chosen), error=str(e)))
            return gr.update(), gr.update(), gr.update()
        except Exception as e:
            gr.Warning(t("messages.example_error", error=str(e)))
            return gr.update(), gr.update(), gr.update()

    except Exception as e:
        gr.Warning(t("messages.example_error", error=str(e)))
        return gr.update(), gr.update(), gr.update()
|
|
|
|
def refresh_checkpoints(dit_handler):
    """Re-query the handler for checkpoints and refresh the dropdown choices."""
    available = dit_handler.get_available_checkpoints()
    return gr.update(choices=available)
|
|
|
|
def update_model_type_settings(config_path):
    """Update UI settings based on model type (fallback when handler not initialized yet)

    Note: This is used as a fallback when the user changes config_path dropdown
    before initializing the model. The actual settings are determined by the
    handler's is_turbo_model() method after initialization.
    """
    lowered = (config_path or "").lower()

    # "turbo" wins over "base" when both appear; anything unrecognized
    # defaults to turbo as well.
    if "base" in lowered and "turbo" not in lowered:
        is_turbo = False
    else:
        is_turbo = True

    return get_model_type_ui_settings(is_turbo)
|
|
|
|
def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
    """Wrapper for service initialization.

    Initializes the DiT service and optionally the 5Hz LM, then returns the
    combined status text, generate-button state, init-accordion state, and
    the model-type-dependent UI settings.

    Returns:
        Tuple of (status_text, generate_btn update, accordion update,
        *get_model_type_ui_settings(is_turbo) updates).
    """
    status, enable = dit_handler.initialize_service(
        checkpoint, config_path, device,
        use_flash_attention=use_flash_attention, compile_model=False,
        offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu
    )

    if init_llm:
        current_file = os.path.abspath(__file__)
        # The project root sits four directory levels above this module.
        project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
        checkpoint_dir = os.path.join(project_root, "checkpoints")

        lm_status, _lm_success = llm_handler.initialize(
            checkpoint_dir=checkpoint_dir,
            lm_model_path=lm_model_path,
            backend=backend,
            device=device,
            offload_to_cpu=offload_to_cpu,
            dtype=dit_handler.dtype
        )

        # Append the LM status regardless of success — the message itself
        # carries the failure details. (The previous success/failure branches
        # were byte-identical, a dead conditional.)
        status += f"\n{lm_status}"

    # Collapse the initialization accordion once a model is loaded.
    is_model_initialized = dit_handler.model is not None
    accordion_state = gr.Accordion(open=not is_model_initialized)

    is_turbo = dit_handler.is_turbo_model()
    model_type_settings = get_model_type_ui_settings(is_turbo)

    return (
        status,
        gr.update(interactive=enable),
        accordion_state,
        *model_type_settings
    )
|
|
|
|
def get_model_type_ui_settings(is_turbo: bool):
    """Get UI settings based on whether the model is turbo or base"""
    if is_turbo:
        # Turbo: few steps, base-only controls hidden.
        steps_update = gr.update(value=8, maximum=20, minimum=1)
        base_controls_visible = False
        task_choices = TASK_TYPES_TURBO
    else:
        # Base: many steps, all controls visible.
        steps_update = gr.update(value=32, maximum=200, minimum=1)
        base_controls_visible = True
        task_choices = TASK_TYPES_BASE

    return (
        steps_update,
        gr.update(visible=base_controls_visible),
        gr.update(visible=base_controls_visible),
        gr.update(value=3.0, visible=True),
        gr.update(visible=base_controls_visible),
        gr.update(visible=base_controls_visible),
        gr.update(choices=task_choices),
    )
|
|
|
|
def update_negative_prompt_visibility(init_llm_checked):
    """Show the negative prompt only while the Initialize 5Hz LM box is checked."""
    show_prompt = init_llm_checked
    return gr.update(visible=show_prompt)
|
|
|
|
def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
    """Update audio_cover_strength visibility and label"""
    is_cover_task = task_type_value == "cover"
    show_slider = is_cover_task or init_llm_checked

    # With the LM active on a non-cover task, the slider controls LM codes.
    if init_llm_checked and not is_cover_task:
        label = "LM codes strength"
        info = "Control how many denoising steps use LM-generated codes"
    else:
        label = "Audio Cover Strength"
        info = "Control how many denoising steps use cover mode"

    return gr.update(visible=show_slider, label=label, info=info)
|
|
|
|
def convert_src_audio_to_codes_wrapper(dit_handler, src_audio):
    """Delegate source-audio-to-codes conversion to the DiT handler."""
    return dit_handler.convert_src_audio_to_codes(src_audio)
|
|
|
|
def update_instruction_ui(
    dit_handler,
    task_type_value: str,
    track_name_value: Optional[str],
    complete_track_classes_value: list,
    audio_codes_content: str = "",
    init_llm_checked: bool = False
) -> tuple:
    """Update instruction and UI visibility based on task type."""
    instruction = dit_handler.generate_instruction(
        task_type=task_type_value,
        track_name=track_name_value,
        complete_track_classes=complete_track_classes_value
    )

    # Per-task visibility flags.
    show_track_name = task_type_value in ["lego", "extract"]
    show_complete = task_type_value == "complete"
    show_repainting = task_type_value in ["repaint", "lego"]
    show_cover_strength = (task_type_value == "cover") or init_llm_checked

    # With the LM active on a non-cover task, the slider controls LM codes.
    if init_llm_checked and task_type_value != "cover":
        strength_label = "LM codes strength"
        strength_info = "Control how many denoising steps use LM-generated codes"
    else:
        strength_label = "Audio Cover Strength"
        strength_info = "Control how many denoising steps use cover mode"

    # Keep the codes box visible for text2music, or whenever codes exist.
    codes_present = audio_codes_content and str(audio_codes_content).strip()
    show_codes = task_type_value == "text2music" or codes_present

    return (
        instruction,
        gr.update(visible=show_track_name),
        gr.update(visible=show_complete),
        gr.update(visible=show_cover_strength, label=strength_label, info=strength_info),
        gr.update(visible=show_repainting),
        gr.update(visible=show_codes),
    )
|
|
|
|
def transcribe_audio_codes(llm_handler, audio_code_string, constrained_decoding_debug):
    """
    Transcribe audio codes to metadata using LLM understanding.
    If audio_code_string is empty, generate a sample example instead.

    This is a Gradio wrapper around the understand_music API in acestep.inference.

    Args:
        llm_handler: LLM handler instance
        audio_code_string: String containing audio codes (or empty for example generation)
        constrained_decoding_debug: Whether to enable debug logging for constrained decoding

    Returns:
        Tuple of (status_message, caption, lyrics, bpm, duration, keyscale,
        language, timesignature, is_format_caption)
    """
    outcome = understand_music(
        llm_handler=llm_handler,
        audio_codes=audio_code_string,
        use_constrained_decoding=True,
        constrained_decoding_debug=constrained_decoding_debug,
    )

    if outcome.success:
        return (
            outcome.status_message,
            outcome.caption,
            outcome.lyrics,
            outcome.bpm,
            outcome.duration,
            outcome.keyscale,
            outcome.language,
            outcome.timesignature,
            True,
        )

    # Failure: substitute a localized message when the LM was never loaded.
    message = outcome.status_message
    if outcome.error == "LLM not initialized":
        message = t("messages.lm_not_initialized")
    return message, "", "", None, None, "", "", "", False
|
|
|
|
def update_transcribe_button_text(audio_code_string):
    """
    Pick the transcribe button caption from the input state:
    empty input -> "Generate Example"; non-empty -> "Transcribe".
    """
    has_content = bool(audio_code_string and audio_code_string.strip())
    label = "Transcribe" if has_content else "Generate Example"
    return gr.update(value=label)
|
|
|
|
def reset_format_caption_flag():
    """Clear the is_format_caption state after a manual caption/metadata edit."""
    return False
|
|
|
|
def update_audio_uploads_accordion(reference_audio, src_audio):
    """Open the Audio Uploads accordion whenever either audio input is set."""
    any_audio = not (reference_audio is None and src_audio is None)
    return gr.Accordion(open=any_audio)
|
|
|
|
def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
    """
    Keep the lyrics box in sync with the instrumental checkbox.

    Checked: blank lyrics are replaced with "[Instrumental]".
    Unchecked: lyrics equal to "[Instrumental]" are cleared.
    Any other lyrics are passed through untouched.
    """
    is_blank = not current_lyrics or not current_lyrics.strip()

    if instrumental_checked:
        return "[Instrumental]" if is_blank else current_lyrics

    if not is_blank and current_lyrics.strip() == "[Instrumental]":
        return ""
    return current_lyrics
|
|
|
|
def handle_simple_instrumental_change(is_instrumental: bool):
    """
    Sync the simple-mode vocal language dropdown with the instrumental flag.
    Checked: force "unknown" and lock the field; unchecked: unlock it.

    Args:
        is_instrumental: Whether instrumental checkbox is checked

    Returns:
        gr.update for simple_vocal_language dropdown
    """
    if not is_instrumental:
        return gr.update(interactive=True)
    # Instrumental tracks have no vocal language; pin and lock the field.
    return gr.update(value="unknown", interactive=False)
|
|
|
|
def update_audio_components_visibility(batch_size):
    """Show/hide the per-sample audio components from the batch size (1-8).

    Row 1 holds components 1-4; row 2 (preceded by its container's update)
    holds components 5-8.
    """
    # Clamp to the supported 1-8 range.
    count = min(max(int(batch_size), 1), 8)

    # Row 1: component k is shown once the batch reaches k samples.
    row_one = tuple(gr.update(visible=count >= k) for k in (1, 2, 3, 4))

    # Row 2: the row container itself plus components 5-8.
    row_two = (gr.update(visible=count >= 5),) + tuple(
        gr.update(visible=count >= k) for k in (5, 6, 7, 8)
    )

    return row_one + row_two
|
|
|
|
def handle_generation_mode_change(mode: str):
    """
    Switch the UI between Simple and Custom generation modes.

    Simple mode shows the simple-mode group, collapses the caption, lyrics,
    and optional-parameters accordions, disables the generate button, and
    resets the sample-created state. Custom mode does the opposite.

    Args:
        mode: "simple" or "custom"

    Returns:
        Tuple of updates for:
        - simple_mode_group (visibility)
        - caption_accordion (open state)
        - lyrics_accordion (open state)
        - generate_btn (interactive state)
        - simple_sample_created (reset state)
        - optional_params_accordion (visibility)
    """
    simple = mode == "simple"
    expanded = not simple

    return (
        gr.update(visible=simple),
        gr.Accordion(open=expanded),
        gr.Accordion(open=expanded),
        gr.update(interactive=expanded),
        False,
        gr.Accordion(open=expanded),
    )
|
|
|
|
def handle_create_sample(
    llm_handler,
    query: str,
    instrumental: bool,
    vocal_language: str,
    lm_temperature: float,
    lm_top_k: int,
    lm_top_p: float,
    constrained_decoding_debug: bool = False,
):
    """
    Handle the Create Sample button click in Simple mode.

    Creates a sample from the user's query using the LLM, then populates
    the caption, lyrics, and metadata fields.

    Note: cfg_scale and negative_prompt are not supported in create_sample mode.

    Args:
        llm_handler: LLM handler instance
        query: User's natural language music description
        instrumental: Whether to generate instrumental music
        vocal_language: Preferred vocal language for constrained decoding
        lm_temperature: LLM temperature for generation
        lm_top_k: LLM top-k sampling
        lm_top_p: LLM top-p sampling
        constrained_decoding_debug: Whether to enable debug logging

    Returns:
        15-tuple of updates for:
        - captions
        - lyrics
        - bpm
        - audio_duration
        - key_scale
        - vocal_language
        - time_signature
        - instrumental_checkbox
        - caption_accordion (open)
        - lyrics_accordion (open)
        - generate_btn (interactive)
        - simple_sample_created
        - think_checkbox
        - is_format_caption_state
        - status_output
    """

    def _failure(message):
        # 15 outputs mirroring the success path: leave the eight content
        # fields and both accordions unchanged, keep the generate button
        # disabled, reset simple_sample_created, and report the status.
        # (Previously the LM-failure path returned 16 items — one stray
        # extra gr.update() — misaligning every output after it.)
        return (
            gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(),
            gr.update(interactive=False),
            False,
            gr.update(),
            gr.update(),
            message,
        )

    if not llm_handler.llm_initialized:
        gr.Warning(t("messages.lm_not_initialized"))
        return _failure(t("messages.lm_not_initialized"))

    # top_k of 0 means "disabled"; top_p >= 1.0 means "disabled".
    top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
    top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p

    result = create_sample(
        llm_handler=llm_handler,
        query=query,
        instrumental=instrumental,
        vocal_language=vocal_language,
        temperature=lm_temperature,
        top_k=top_k_value,
        top_p=top_p_value,
        use_constrained_decoding=True,
        constrained_decoding_debug=constrained_decoding_debug,
    )

    if not result.success:
        gr.Warning(result.status_message or t("messages.sample_creation_failed"))
        return _failure(result.status_message or t("messages.sample_creation_failed"))

    gr.Info(t("messages.sample_created"))

    return (
        result.caption,
        result.lyrics,
        result.bpm,
        # -1 signals "auto" duration when the LM produced none.
        result.duration if result.duration and result.duration > 0 else -1,
        result.keyscale,
        result.language,        # vocal_language (was duplicated, shifting every later output)
        result.timesignature,
        result.instrumental,
        gr.Accordion(open=True),
        gr.Accordion(open=True),
        gr.update(interactive=True),
        True,                   # simple_sample_created
        True,                   # think_checkbox
        True,                   # is_format_caption_state
        result.status_message,
    )
|
|
|
|
def handle_format_sample(
    llm_handler,
    caption: str,
    lyrics: str,
    bpm,
    audio_duration,
    key_scale: str,
    time_signature: str,
    lm_temperature: float,
    lm_top_k: int,
    lm_top_p: float,
    constrained_decoding_debug: bool = False,
):
    """
    Handle the Format button click to format caption and lyrics.

    Takes user-provided caption and lyrics, and uses the LLM to generate
    structured music metadata and an enhanced description.

    Note: cfg_scale and negative_prompt are not supported in format mode.

    Args:
        llm_handler: LLM handler instance
        caption: User's caption/description
        lyrics: User's lyrics
        bpm: User-provided BPM (optional, for constrained decoding)
        audio_duration: User-provided duration (optional, for constrained decoding)
        key_scale: User-provided key scale (optional, for constrained decoding)
        time_signature: User-provided time signature (optional, for constrained decoding)
        lm_temperature: LLM temperature for generation
        lm_top_k: LLM top-k sampling
        lm_top_p: LLM top-p sampling
        constrained_decoding_debug: Whether to enable debug logging

    Returns:
        9-tuple of updates for: captions, lyrics, bpm, audio_duration,
        key_scale, vocal_language, time_signature, is_format_caption_state,
        status_output.
    """

    def _unchanged(message):
        # Leave all eight fields untouched; only report the status text.
        return tuple(gr.update() for _ in range(8)) + (message,)

    if not llm_handler.llm_initialized:
        gr.Warning(t("messages.lm_not_initialized"))
        return _unchanged(t("messages.lm_not_initialized"))

    # Collect any user-pinned metadata for constrained decoding.
    user_metadata = {}
    if bpm is not None and bpm > 0:
        user_metadata['bpm'] = int(bpm)
    if audio_duration is not None and float(audio_duration) > 0:
        user_metadata['duration'] = int(audio_duration)
    if key_scale and key_scale.strip():
        user_metadata['keyscale'] = key_scale.strip()
    if time_signature and time_signature.strip():
        user_metadata['timesignature'] = time_signature.strip()

    # top_k of 0 disables top-k; top_p >= 1.0 disables nucleus sampling.
    top_k_value = int(lm_top_k) if lm_top_k else None
    top_p_value = lm_top_p if lm_top_p and lm_top_p < 1.0 else None

    result = format_sample(
        llm_handler=llm_handler,
        caption=caption,
        lyrics=lyrics,
        user_metadata=user_metadata if user_metadata else None,
        temperature=lm_temperature,
        top_k=top_k_value,
        top_p=top_p_value,
        use_constrained_decoding=True,
        constrained_decoding_debug=constrained_decoding_debug,
    )

    if not result.success:
        failure_message = result.status_message or t("messages.format_failed")
        gr.Warning(failure_message)
        return _unchanged(failure_message)

    gr.Info(t("messages.format_success"))

    # -1 signals "auto" duration when the LM produced none.
    formatted_duration = result.duration if result.duration and result.duration > 0 else -1
    return (
        result.caption,
        result.lyrics,
        result.bpm,
        formatted_duration,
        result.keyscale,
        result.language,
        result.timesignature,
        True,
        result.status_message,
    )
|
|
|
|