"""
Gradio UI Generation Section Module
Contains generation section component definitions
"""
import gradio as gr
from acestep.constants import (
    VALID_LANGUAGES,
    TRACK_NAMES,
    TASK_TYPES_TURBO,
    TASK_TYPES_BASE,
    DEFAULT_DIT_INSTRUCTION,
)
from acestep.gradio_ui.i18n import t


def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
    """Build the generation section of the Gradio UI.

    The section holds the service configuration panel (model / device /
    LM selection, LoRA controls), the generation inputs (task type, audio
    uploads, caption, lyrics, optional metadata), the advanced DiT / LM
    settings and the generate row.

    Args:
        dit_handler: DiT handler instance. Queried for the available
            checkpoints, the available acestep-v15 models and flash
            attention availability.
        llm_handler: LM handler instance. Queried for the available 5Hz LM
            models.
        init_params: Dictionary containing initialization parameters and
            state. If None, the service is treated as not pre-initialized
            and every control starts from its built-in default.
        language: UI language code ('en', 'zh', 'ja'). Used as the initial
            language unless ``init_params`` provides a ``'language'`` entry.

    Returns:
        dict: Maps component names to the Gradio components created here.
    """

    def startup_flag(name):
        # Same value as `init_params is not None and init_params.get(name, False)`.
        return init_params.get(name, False) if init_params is not None else False

    # Whether the service was initialized before the UI was built
    service_pre_initialized = startup_flag('pre_initialized')
    # Whether the UI runs in service mode (restricted UI)
    service_mode = startup_flag('service_mode')

    def preset(key, fallback=None):
        """Return the pre-initialized value for *key*, or *fallback* otherwise."""
        if service_pre_initialized:
            return init_params.get(key, fallback)
        return fallback

    # The language stored in init_params wins over the `language` argument
    if init_params:
        current_language = init_params.get('language', language)
    else:
        current_language = language

    # Controls that are locked / hidden in service mode all share this flag
    unrestricted = not service_mode

    # NOTE(review): container nesting below follows the section comments;
    # confirm against the intended layout.
    with gr.Group():
        # Service Configuration - collapse if pre-initialized, hide if in service mode
        # NOTE(review): visibility is driven by the pre-initialized flag, not by
        # service_mode as the comment above suggests — confirm which is intended.
        accordion_open = not service_pre_initialized
        accordion_visible = not service_pre_initialized
        with gr.Accordion(
            t("service.title"),
            open=accordion_open,
            visible=accordion_visible,
        ) as service_config_accordion:
            # Language selector at the top
            with gr.Row():
                language_choices = [
                    ("English", "en"),
                    ("中文", "zh"),
                    ("日本語", "ja"),
                ]
                language_dropdown = gr.Dropdown(
                    choices=language_choices,
                    value=current_language,
                    label=t("service.language_label"),
                    info=t("service.language_info"),
                    scale=1,
                )

            # Dropdown options section - all dropdowns grouped together
            with gr.Row(equal_height=True):
                with gr.Column(scale=4):
                    # Checkpoint comes from init_params only when pre-initialized
                    checkpoint_value = preset('checkpoint')
                    checkpoint_dropdown = gr.Dropdown(
                        label=t("service.checkpoint_label"),
                        choices=dit_handler.get_available_checkpoints(),
                        value=checkpoint_value,
                        info=t("service.checkpoint_info")
                    )
                with gr.Column(scale=1, min_width=90):
                    refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")

            with gr.Row():
                # Available acestep-v15- models; prefer turbo, else the first one
                available_models = dit_handler.get_available_acestep_v15_models()
                if "acestep-v15-turbo" in available_models:
                    default_model = "acestep-v15-turbo"
                elif available_models:
                    default_model = available_models[0]
                else:
                    default_model = None

                config_path_value = preset('config_path', default_model)
                config_path = gr.Dropdown(
                    label=t("service.model_path_label"),
                    choices=available_models,
                    value=config_path_value,
                    info=t("service.model_path_info")
                )

                device_value = preset('device', 'auto')
                device = gr.Dropdown(
                    choices=["auto", "cuda", "cpu"],
                    value=device_value,
                    label=t("service.device_label"),
                    info=t("service.device_info")
                )

            with gr.Row():
                # Available 5Hz LM models; prefer the 0.6B one, else the first one
                available_lm_models = llm_handler.get_available_5hz_lm_models()
                if "acestep-5Hz-lm-0.6B" in available_lm_models:
                    default_lm_model = "acestep-5Hz-lm-0.6B"
                elif available_lm_models:
                    default_lm_model = available_lm_models[0]
                else:
                    default_lm_model = None

                lm_model_path_value = preset('lm_model_path', default_lm_model)
                lm_model_path = gr.Dropdown(
                    label=t("service.lm_model_path_label"),
                    choices=available_lm_models,
                    value=lm_model_path_value,
                    info=t("service.lm_model_path_info")
                )

                backend_value = preset('backend', 'vllm')
                backend_dropdown = gr.Dropdown(
                    choices=["vllm", "pt"],
                    value=backend_value,
                    label=t("service.backend_label"),
                    info=t("service.backend_info")
                )

            # Checkbox options section - all checkboxes grouped together
            with gr.Row():
                init_llm_value = preset('init_llm', True)
                init_llm_checkbox = gr.Checkbox(
                    label=t("service.init_llm_label"),
                    value=init_llm_value,
                    info=t("service.init_llm_info"),
                )

                # Flash attention defaults to whatever the handler detects and
                # can only be toggled when it is available
                flash_attn_available = dit_handler.is_flash_attention_available()
                use_flash_attention_value = preset('use_flash_attention', flash_attn_available)
                if flash_attn_available:
                    flash_attention_info = t("service.flash_attention_info_enabled")
                else:
                    flash_attention_info = t("service.flash_attention_info_disabled")
                use_flash_attention_checkbox = gr.Checkbox(
                    label=t("service.flash_attention_label"),
                    value=use_flash_attention_value,
                    interactive=flash_attn_available,
                    info=flash_attention_info
                )

                offload_to_cpu_value = preset('offload_to_cpu', False)
                offload_to_cpu_checkbox = gr.Checkbox(
                    label=t("service.offload_cpu_label"),
                    value=offload_to_cpu_value,
                    info=t("service.offload_cpu_info")
                )

                offload_dit_to_cpu_value = preset('offload_dit_to_cpu', False)
                offload_dit_to_cpu_checkbox = gr.Checkbox(
                    label=t("service.offload_dit_cpu_label"),
                    value=offload_dit_to_cpu_value,
                    info=t("service.offload_dit_cpu_info")
                )

            init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")

            init_status_value = preset('init_status', '')
            init_status = gr.Textbox(
                label=t("service.status_label"),
                interactive=False,
                lines=3,
                value=init_status_value,
            )

            # LoRA Configuration Section
            # NOTE(review): the heading text carries no markup here — confirm
            # whether HTML tags were intended.
            gr.HTML("\n\n🔧 LoRA Adapter\n\n")
            with gr.Row():
                lora_path = gr.Textbox(
                    label="LoRA Path",
                    placeholder="./lora_output/final/adapter",
                    info="Path to trained LoRA adapter directory",
                    scale=3,
                )
                load_lora_btn = gr.Button("📥 Load LoRA", variant="secondary", scale=1)
                unload_lora_btn = gr.Button("🗑️ Unload", variant="secondary", scale=1)
            with gr.Row():
                use_lora_checkbox = gr.Checkbox(
                    label="Use LoRA",
                    value=False,
                    info="Enable LoRA adapter for inference",
                    scale=1,
                )
                lora_status = gr.Textbox(
                    label="LoRA Status",
                    value="No LoRA loaded",
                    interactive=False,
                    scale=2,
                )

        # Inputs
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Accordion(t("generation.required_inputs"), open=True):
                    # Task type
                    # Initial task_type choices depend on the model in use, which
                    # is the same value that seeds the config_path dropdown
                    model_name_lower = (config_path_value or "").lower()
                    if "turbo" in model_name_lower:
                        initial_task_choices = TASK_TYPES_TURBO
                    else:
                        initial_task_choices = TASK_TYPES_BASE

                    with gr.Row(equal_height=True):
                        with gr.Column(scale=2):
                            task_type = gr.Dropdown(
                                choices=initial_task_choices,
                                value="text2music",
                                label=t("generation.task_type_label"),
                                info=t("generation.task_type_info"),
                            )
                        with gr.Column(scale=7):
                            instruction_display_gen = gr.Textbox(
                                label=t("generation.instruction_label"),
                                value=DEFAULT_DIT_INSTRUCTION,
                                interactive=False,
                                lines=1,
                                info=t("generation.instruction_info"),
                            )
                        with gr.Column(scale=1, min_width=100):
                            load_file = gr.UploadButton(
                                t("generation.load_btn"),
                                file_types=[".json"],
                                file_count="single",
                                variant="secondary",
                                size="sm",
                            )

                    track_name = gr.Dropdown(
                        choices=TRACK_NAMES,
                        value=None,
                        label=t("generation.track_name_label"),
                        info=t("generation.track_name_info"),
                        visible=False
                    )

                    complete_track_classes = gr.CheckboxGroup(
                        choices=TRACK_NAMES,
                        label=t("generation.track_classes_label"),
                        info=t("generation.track_classes_info"),
                        visible=False
                    )

                    # Audio uploads
                    audio_uploads_accordion = gr.Accordion(t("generation.audio_uploads"), open=False)
                    with audio_uploads_accordion:
                        with gr.Row(equal_height=True):
                            with gr.Column(scale=2):
                                reference_audio = gr.Audio(
                                    label=t("generation.reference_audio"),
                                    type="filepath",
                                )
                            with gr.Column(scale=7):
                                src_audio = gr.Audio(
                                    label=t("generation.source_audio"),
                                    type="filepath",
                                )
                            with gr.Column(scale=1, min_width=80):
                                convert_src_to_codes_btn = gr.Button(
                                    t("generation.convert_codes_btn"),
                                    variant="secondary",
                                    size="sm"
                                )

                    # Audio Codes for text2music - single input for transcription or cover task
                    with gr.Accordion(
                        t("generation.lm_codes_hints"),
                        open=False,
                        visible=True,
                    ) as text2music_audio_codes_group:
                        with gr.Row(equal_height=True):
                            text2music_audio_code_string = gr.Textbox(
                                label=t("generation.lm_codes_label"),
                                placeholder=t("generation.lm_codes_placeholder"),
                                lines=6,
                                info=t("generation.lm_codes_info"),
                                scale=9,
                            )
                            transcribe_btn = gr.Button(
                                t("generation.transcribe_btn"),
                                variant="secondary",
                                size="sm",
                                scale=1,
                            )

                    # Repainting controls
                    with gr.Group(visible=False) as repainting_group:
                        gr.HTML(f"\n{t('generation.repainting_controls')}\n")
                        with gr.Row():
                            repainting_start = gr.Number(
                                label=t("generation.repainting_start"),
                                value=0.0,
                                step=0.1,
                            )
                            repainting_end = gr.Number(
                                label=t("generation.repainting_end"),
                                value=-1,
                                minimum=-1,
                                step=0.1,
                            )

                # Simple/Custom Mode Toggle
                # In service mode: only Custom mode, hide the toggle
                with gr.Row(visible=unrestricted):
                    mode_choices = [
                        (t("generation.mode_simple"), "simple"),
                        (t("generation.mode_custom"), "custom"),
                    ]
                    generation_mode = gr.Radio(
                        choices=mode_choices,
                        value="custom" if service_mode else "simple",
                        label=t("generation.mode_label"),
                        info=t("generation.mode_info"),
                    )

                # Simple Mode Components - hidden in service mode
                with gr.Group(visible=unrestricted) as simple_mode_group:
                    with gr.Row(equal_height=True):
                        simple_query_input = gr.Textbox(
                            label=t("generation.simple_query_label"),
                            placeholder=t("generation.simple_query_placeholder"),
                            lines=2,
                            info=t("generation.simple_query_info"),
                            scale=12,
                        )
                        with gr.Column(scale=1, min_width=100):
                            random_desc_btn = gr.Button(
                                "🎲",
                                variant="secondary",
                                size="sm",
                                scale=2
                            )

                    with gr.Row(equal_height=True):
                        with gr.Column(scale=1, variant="compact"):
                            simple_instrumental_checkbox = gr.Checkbox(
                                label=t("generation.instrumental_label"),
                                value=False,
                            )
                        with gr.Column(scale=18):
                            create_sample_btn = gr.Button(
                                t("generation.create_sample_btn"),
                                variant="primary",
                                size="lg",
                            )
                        with gr.Column(scale=1, variant="compact"):
                            simple_vocal_language = gr.Dropdown(
                                choices=VALID_LANGUAGES,
                                value="unknown",
                                allow_custom_value=True,
                                label=t("generation.simple_vocal_language_label"),
                                interactive=True,
                            )

                # State to track if sample has been created in Simple mode
                simple_sample_created = gr.State(value=False)

                # Music Caption - wrapped in accordion that can be collapsed in Simple mode
                # In service mode: auto-expand
                with gr.Accordion(t("generation.caption_title"), open=service_mode) as caption_accordion:
                    with gr.Row(equal_height=True):
                        captions = gr.Textbox(
                            label=t("generation.caption_label"),
                            placeholder=t("generation.caption_placeholder"),
                            lines=3,
                            info=t("generation.caption_info"),
                            scale=12,
                        )
                        with gr.Column(scale=1, min_width=100):
                            sample_btn = gr.Button(
                                "🎲",
                                variant="secondary",
                                size="sm",
                                scale=2,
                            )

                # Lyrics - wrapped in accordion that can be collapsed in Simple mode
                # In service mode: auto-expand
                with gr.Accordion(t("generation.lyrics_title"), open=service_mode) as lyrics_accordion:
                    lyrics = gr.Textbox(
                        label=t("generation.lyrics_label"),
                        placeholder=t("generation.lyrics_placeholder"),
                        lines=8,
                        info=t("generation.lyrics_info")
                    )

                    with gr.Row(variant="compact", equal_height=True):
                        instrumental_checkbox = gr.Checkbox(
                            label=t("generation.instrumental_label"),
                            value=False,
                            scale=1,
                            min_width=120,
                            container=True,
                        )

                        # Middle: language selection (Dropdown)
                        # The gr.HTML hack was removed; the label parameter is used
                        # directly and Gradio handles the alignment
                        vocal_language = gr.Dropdown(
                            choices=VALID_LANGUAGES,
                            value="unknown",
                            label=t("generation.vocal_language_label"),
                            show_label=False,
                            container=True,
                            allow_custom_value=True,
                            scale=3,
                        )

                        # Right: format button (Button)
                        # Placed at the far right of the same row for easier access
                        format_btn = gr.Button(
                            t("generation.format_btn"),
                            variant="secondary",
                            scale=1,
                            min_width=80,
                        )

                # Optional Parameters
                # In service mode: auto-expand
                with gr.Accordion(
                    t("generation.optional_params"),
                    open=service_mode,
                ) as optional_params_accordion:
                    with gr.Row():
                        bpm = gr.Number(
                            label=t("generation.bpm_label"),
                            value=None,
                            step=1,
                            info=t("generation.bpm_info")
                        )
                        key_scale = gr.Textbox(
                            label=t("generation.keyscale_label"),
                            placeholder=t("generation.keyscale_placeholder"),
                            value="",
                            info=t("generation.keyscale_info")
                        )
                        time_signature = gr.Dropdown(
                            choices=["2", "3", "4", "N/A", ""],
                            value="",
                            label=t("generation.timesig_label"),
                            allow_custom_value=True,
                            info=t("generation.timesig_info")
                        )
                        audio_duration = gr.Number(
                            label=t("generation.duration_label"),
                            value=-1,
                            minimum=-1,
                            maximum=600.0,
                            step=0.1,
                            info=t("generation.duration_info")
                        )
                        batch_size_input = gr.Number(
                            label=t("generation.batch_size_label"),
                            value=2,
                            minimum=1,
                            maximum=8,
                            step=1,
                            info=t("generation.batch_size_info"),
                            interactive=unrestricted  # Fixed in service mode
                        )

        # Advanced Settings
        # Default UI settings use turbo mode (max 20 steps, default 8, show shift with default 3)
        # These will be updated after model initialization based on handler.is_turbo_model()
        with gr.Accordion(t("generation.advanced_settings"), open=False):
            with gr.Row():
                inference_steps = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=8,
                    step=1,
                    label=t("generation.inference_steps_label"),
                    info=t("generation.inference_steps_info")
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=15.0,
                    value=7.0,
                    step=0.1,
                    label=t("generation.guidance_scale_label"),
                    info=t("generation.guidance_scale_info"),
                    visible=False
                )
                with gr.Column():
                    seed = gr.Textbox(
                        label=t("generation.seed_label"),
                        value="-1",
                        info=t("generation.seed_info")
                    )
                    random_seed_checkbox = gr.Checkbox(
                        label=t("generation.random_seed_label"),
                        value=True,
                        info=t("generation.random_seed_info")
                    )
                audio_format = gr.Dropdown(
                    choices=["mp3", "flac"],
                    value="mp3",
                    label=t("generation.audio_format_label"),
                    info=t("generation.audio_format_info"),
                    interactive=unrestricted  # Fixed in service mode
                )

            with gr.Row():
                use_adg = gr.Checkbox(
                    label=t("generation.use_adg_label"),
                    value=False,
                    info=t("generation.use_adg_info"),
                    visible=False
                )
                shift = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    value=3.0,
                    step=0.1,
                    label=t("generation.shift_label"),
                    info=t("generation.shift_info"),
                    visible=True
                )
                infer_method = gr.Dropdown(
                    choices=["ode", "sde"],
                    value="ode",
                    label=t("generation.infer_method_label"),
                    info=t("generation.infer_method_info"),
                )

            with gr.Row():
                custom_timesteps = gr.Textbox(
                    label=t("generation.custom_timesteps_label"),
                    placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
                    value="",
                    info=t("generation.custom_timesteps_info"),
                )

            with gr.Row():
                cfg_interval_start = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.01,
                    label=t("generation.cfg_interval_start"),
                    visible=False
                )
                cfg_interval_end = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.01,
                    label=t("generation.cfg_interval_end"),
                    visible=False
                )

            # LM (Language Model) Parameters
            gr.HTML(f"\n\n{t('generation.lm_params_title')}\n\n")
            with gr.Row():
                lm_temperature = gr.Slider(
                    label=t("generation.lm_temperature_label"),
                    minimum=0.0,
                    maximum=2.0,
                    value=0.85,
                    step=0.1,
                    scale=1,
                    info=t("generation.lm_temperature_info")
                )
                lm_cfg_scale = gr.Slider(
                    label=t("generation.lm_cfg_scale_label"),
                    minimum=1.0,
                    maximum=3.0,
                    value=2.0,
                    step=0.1,
                    scale=1,
                    info=t("generation.lm_cfg_scale_info")
                )
                lm_top_k = gr.Slider(
                    label=t("generation.lm_top_k_label"),
                    minimum=0,
                    maximum=100,
                    value=0,
                    step=1,
                    scale=1,
                    info=t("generation.lm_top_k_info")
                )
                lm_top_p = gr.Slider(
                    label=t("generation.lm_top_p_label"),
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9,
                    step=0.01,
                    scale=1,
                    info=t("generation.lm_top_p_info")
                )

            with gr.Row():
                lm_negative_prompt = gr.Textbox(
                    label=t("generation.lm_negative_prompt_label"),
                    value="NO USER INPUT",
                    placeholder=t("generation.lm_negative_prompt_placeholder"),
                    info=t("generation.lm_negative_prompt_info"),
                    lines=2,
                    scale=2,
                )

            with gr.Row():
                use_cot_metas = gr.Checkbox(
                    label=t("generation.cot_metas_label"),
                    value=True,
                    info=t("generation.cot_metas_info"),
                    scale=1,
                )
                use_cot_language = gr.Checkbox(
                    label=t("generation.cot_language_label"),
                    value=True,
                    info=t("generation.cot_language_info"),
                    scale=1,
                )
                constrained_decoding_debug = gr.Checkbox(
                    label=t("generation.constrained_debug_label"),
                    value=False,
                    info=t("generation.constrained_debug_info"),
                    scale=1,
                    interactive=unrestricted  # Fixed in service mode
                )

            with gr.Row():
                auto_score = gr.Checkbox(
                    label=t("generation.auto_score_label"),
                    value=False,
                    info=t("generation.auto_score_info"),
                    scale=1,
                    interactive=unrestricted  # Fixed in service mode
                )
                auto_lrc = gr.Checkbox(
                    label=t("generation.auto_lrc_label"),
                    value=False,
                    info=t("generation.auto_lrc_info"),
                    scale=1,
                    interactive=unrestricted  # Fixed in service mode
                )
                lm_batch_chunk_size = gr.Number(
                    label=t("generation.lm_batch_chunk_label"),
                    value=8,
                    minimum=1,
                    maximum=32,
                    step=1,
                    info=t("generation.lm_batch_chunk_info"),
                    scale=1,
                    interactive=unrestricted  # Fixed in service mode
                )

            with gr.Row():
                audio_cover_strength = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.01,
                    label=t("generation.codes_strength_label"),
                    info=t("generation.codes_strength_info"),
                    scale=1,
                )
                score_scale = gr.Slider(
                    minimum=0.01,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label=t("generation.score_sensitivity_label"),
                    info=t("generation.score_sensitivity_info"),
                    scale=1,
                    visible=unrestricted  # Hidden in service mode
                )

        # The generate button is only clickable when a pre-initialized service
        # explicitly enables it
        generate_btn_interactive = preset('enable_generate', False)
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, variant="compact"):
                think_checkbox = gr.Checkbox(
                    label=t("generation.think_label"),
                    value=True,
                    scale=1,
                )
                allow_lm_batch = gr.Checkbox(
                    label=t("generation.parallel_thinking_label"),
                    value=True,
                    scale=1,
                )
            with gr.Column(scale=18):
                generate_btn = gr.Button(
                    t("generation.generate_btn"),
                    variant="primary",
                    size="lg",
                    interactive=generate_btn_interactive,
                )
            with gr.Column(scale=1, variant="compact"):
                autogen_checkbox = gr.Checkbox(
                    label=t("generation.autogen_label"),
                    value=False,  # Default to False for both service and local modes
                    scale=1,
                    interactive=unrestricted  # Not selectable in service mode
                )
                use_cot_caption = gr.Checkbox(
                    label=t("generation.caption_rewrite_label"),
                    value=True,
                    scale=1,
                )

    return {
        "service_config_accordion": service_config_accordion,
        "language_dropdown": language_dropdown,
        "checkpoint_dropdown": checkpoint_dropdown,
        "refresh_btn": refresh_btn,
        "config_path": config_path,
        "device": device,
        "init_btn": init_btn,
        "init_status": init_status,
        "lm_model_path": lm_model_path,
        "init_llm_checkbox": init_llm_checkbox,
        "backend_dropdown": backend_dropdown,
        "use_flash_attention_checkbox": use_flash_attention_checkbox,
        "offload_to_cpu_checkbox": offload_to_cpu_checkbox,
        "offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
        # LoRA components
        "lora_path": lora_path,
        "load_lora_btn": load_lora_btn,
        "unload_lora_btn": unload_lora_btn,
        "use_lora_checkbox": use_lora_checkbox,
        "lora_status": lora_status,
        "task_type": task_type,
        "instruction_display_gen": instruction_display_gen,
        "track_name": track_name,
        "complete_track_classes": complete_track_classes,
        "audio_uploads_accordion": audio_uploads_accordion,
        "reference_audio": reference_audio,
        "src_audio": src_audio,
        "convert_src_to_codes_btn": convert_src_to_codes_btn,
        "text2music_audio_code_string": text2music_audio_code_string,
        "transcribe_btn": transcribe_btn,
        "text2music_audio_codes_group": text2music_audio_codes_group,
        "lm_temperature": lm_temperature,
        "lm_cfg_scale": lm_cfg_scale,
        "lm_top_k": lm_top_k,
        "lm_top_p": lm_top_p,
        "lm_negative_prompt": lm_negative_prompt,
        "use_cot_metas": use_cot_metas,
        "use_cot_caption": use_cot_caption,
        "use_cot_language": use_cot_language,
        "repainting_group": repainting_group,
        "repainting_start": repainting_start,
        "repainting_end": repainting_end,
        "audio_cover_strength": audio_cover_strength,
        # Simple/Custom Mode Components
        "generation_mode": generation_mode,
        "simple_mode_group": simple_mode_group,
        "simple_query_input": simple_query_input,
        "random_desc_btn": random_desc_btn,
        "simple_instrumental_checkbox": simple_instrumental_checkbox,
        "simple_vocal_language": simple_vocal_language,
        "create_sample_btn": create_sample_btn,
        "simple_sample_created": simple_sample_created,
        "caption_accordion": caption_accordion,
        "lyrics_accordion": lyrics_accordion,
        "optional_params_accordion": optional_params_accordion,
        # Existing components
        "captions": captions,
        "sample_btn": sample_btn,
        "load_file": load_file,
        "lyrics": lyrics,
        "vocal_language": vocal_language,
        "bpm": bpm,
        "key_scale": key_scale,
        "time_signature": time_signature,
        "audio_duration": audio_duration,
        "batch_size_input": batch_size_input,
        "inference_steps": inference_steps,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "random_seed_checkbox": random_seed_checkbox,
        "use_adg": use_adg,
        "cfg_interval_start": cfg_interval_start,
        "cfg_interval_end": cfg_interval_end,
        "shift": shift,
        "infer_method": infer_method,
        "custom_timesteps": custom_timesteps,
        "audio_format": audio_format,
        "think_checkbox": think_checkbox,
        "autogen_checkbox": autogen_checkbox,
        "generate_btn": generate_btn,
        "instrumental_checkbox": instrumental_checkbox,
        "format_btn": format_btn,
        "constrained_decoding_debug": constrained_decoding_debug,
        "score_scale": score_scale,
        "allow_lm_batch": allow_lm_batch,
        "auto_score": auto_score,
        "auto_lrc": auto_lrc,
        "lm_batch_chunk_size": lm_batch_chunk_size,
    }