Spaces:
Running
Running
| import argparse | |
| import re | |
| import ast | |
| import os | |
| import sys | |
| import toml | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
# Load environment variables from .env, falling back to .env.example, when
# python-dotenv is installed. A missing dotenv package is silently tolerated.
try:
    from dotenv import load_dotenv

    _current_file = os.path.abspath(__file__)
    _project_root = os.path.dirname(_current_file)
    _env_path = os.path.join(_project_root, '.env')
    _env_example_path = os.path.join(_project_root, '.env.example')

    if not os.path.exists(_env_path):
        # No real .env: use the example file as a fallback if it exists.
        if os.path.exists(_env_example_path):
            load_dotenv(_env_example_path)
            print(f"Loaded configuration from {_env_example_path} (fallback)")
    else:
        load_dotenv(_env_path)
        print(f"Loaded configuration from {_env_path}")
except ImportError:
    pass
# Clear proxy settings that may affect network behavior. The lowercase and
# uppercase spelling of every variable is removed, because urllib's
# getproxies() reads these environment variables case-insensitively.
for _proxy_var in ['http_proxy', 'https_proxy', 'all_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
    os.environ.pop(_proxy_var, None)
def _configure_logging(
    level: Optional[str] = None,
    suppress_audio_tokens: Optional[bool] = None,
) -> None:
    """Replace loguru's sinks with a single filtered stderr sink.

    Does nothing if loguru cannot be imported.

    Args:
        level: Log level name for the stderr sink (case-insensitive);
            defaults to "INFO".
        suppress_audio_tokens: Drop messages containing ``<|audio_code_``
            tokens. When None, the value comes from the
            ``ACE_STEP_SUPPRESS_AUDIO_TOKENS`` environment variable:
            suppression is on unless the variable is "0", "false", "no" or
            "off" (case-insensitive, surrounding whitespace ignored).
    """
    try:
        from loguru import logger
    except Exception:
        # Logging setup is best-effort; leave logging untouched.
        return

    if suppress_audio_tokens is None:
        raw_flag = os.environ.get("ACE_STEP_SUPPRESS_AUDIO_TOKENS", "1")
        # Normalise so spellings such as "FALSE" or " off " also disable suppression.
        suppress_audio_tokens = raw_flag.strip().lower() not in {"0", "false", "no", "off"}
    level = "INFO" if level is None else str(level).upper()

    def _log_filter(record) -> bool:
        message = record.get("message", "")
        stripped = message.strip()
        # Suppress duplicate DiT prompt logs (we print a single final prompt in cli.py)
        if (
            "DiT TEXT ENCODER INPUT" in message
            or "text_prompt:" in message
            or (stripped and set(stripped) == {"="})
        ):
            return False
        if not suppress_audio_tokens:
            return True
        return "<|audio_code_" not in message

    logger.remove()
    logger.add(sys.stderr, level=level, filter=_log_filter)


# Configure logging once at import time, before the acestep imports that follow.
_configure_logging()
| from acestep.handler import AceStepHandler | |
| from acestep.llm_inference import LLMHandler | |
| from acestep.inference import GenerationParams, GenerationConfig, generate_music, create_sample, format_sample | |
| from acestep.constants import DEFAULT_DIT_INSTRUCTION, TASK_INSTRUCTIONS | |
| from acestep.gpu_config import get_gpu_config, set_global_gpu_config, is_mps_platform | |
| import torch | |
# Instrument/stem names offered by the wizard for lego, extract and complete tasks.
TRACK_CHOICES = [
    "vocals", "backing_vocals", "drums", "bass",
    "guitar", "keyboard", "percussion", "strings",
    "synth", "fx", "brass", "woodwinds",
]
def _get_project_root() -> str:
    """Return the absolute path of the directory that contains this module."""
    head, _tail = os.path.split(os.path.abspath(__file__))
    return head
def _parse_description_hints(description: str) -> tuple[Optional[str], bool]:
    """Infer a vocal language code and an instrumental flag from a description.

    Args:
        description: Free-form music description typed by the user.

    Returns:
        ``(language_code, is_instrumental)``. ``language_code`` comes from the
        first entry of the language table (in table order) found in the text,
        or is None. ``is_instrumental`` is True when the text contains
        "instrumental", "pure music" or "pure instrument", ends with " solo",
        or is exactly "solo".
    """
    if not description:
        return None, False

    description_lower = description.lower().strip()

    language_mapping = {
        'english': 'en', 'en': 'en',
        'chinese': 'zh', '中文': 'zh', 'zh': 'zh', 'mandarin': 'zh',
        'japanese': 'ja', '日本語': 'ja', 'ja': 'ja',
        'korean': 'ko', '한국어': 'ko', 'ko': 'ko',
        'spanish': 'es', 'español': 'es', 'es': 'es',
        'french': 'fr', 'français': 'fr', 'fr': 'fr',
        'german': 'de', 'deutsch': 'de', 'de': 'de',
        'italian': 'it', 'italiano': 'it', 'it': 'it',
        'portuguese': 'pt', 'português': 'pt', 'pt': 'pt',
        'russian': 'ru', 'русский': 'ru', 'ru': 'ru',
        'bengali': 'bn', 'bn': 'bn',
        'hindi': 'hi', 'hi': 'hi',
        'arabic': 'ar', 'ar': 'ar',
        'thai': 'th', 'th': 'th',
        'vietnamese': 'vi', 'vi': 'vi',
        'indonesian': 'id', 'id': 'id',
        'turkish': 'tr', 'tr': 'tr',
        'dutch': 'nl', 'nl': 'nl',
        'polish': 'pl', 'pl': 'pl',
    }
    # Two-letter codes that double as everyday English words ("make it
    # upbeat", "say hi"). Matching them anywhere in the text mis-detects the
    # language, so they only count when they are the whole description.
    ambiguous_codes = {'it', 'hi'}

    detected_language = None
    for lang_name, lang_code in language_mapping.items():
        if lang_name in ambiguous_codes:
            matched = description_lower == lang_name
        elif any(ord(ch) >= 0x2E80 for ch in lang_name):
            # CJK / Hangul names: Chinese and Japanese omit spaces, Korean
            # attaches particles to words, and all of these characters are
            # word characters, so a ``\b`` boundary rarely matches in a sentence.
            matched = lang_name in description_lower
        else:
            if len(lang_name) <= 2:
                pattern = r'(?:^|\s|[.,;:!?])' + re.escape(lang_name) + r'(?:$|\s|[.,;:!?])'
            else:
                pattern = r'\b' + re.escape(lang_name) + r'\b'
            matched = re.search(pattern, description_lower) is not None
        if matched:
            detected_language = lang_code
            break

    is_instrumental = (
        'instrumental' in description_lower
        or 'pure music' in description_lower
        or 'pure instrument' in description_lower
        or description_lower.endswith(' solo')
        or description_lower == 'solo'
    )
    return detected_language, is_instrumental
def _prompt_non_empty(prompt: str) -> str:
    """Repeat ``prompt`` until the user types something other than whitespace.

    Returns the answer with surrounding whitespace removed.
    """
    while True:
        answer = input(prompt).strip()
        if answer:
            return answer
def _prompt_with_default(prompt: str, default: Optional[str] = None, required: bool = False) -> str:
    """Ask for a value, falling back to ``default`` when the answer is empty.

    The default is shown in brackets when it is neither None nor "".
    Returns the stripped answer if one was typed, else ``str(default)`` when a
    default exists, else "" for optional prompts. Required prompts without a
    default keep asking until a value is entered.
    """
    has_default = default not in (None, "")
    question = f"{prompt} [{default}]: " if has_default else f"{prompt}: "
    while True:
        answer = input(question).strip()
        if answer:
            return answer
        if has_default:
            return str(default)
        if not required:
            return ""
        print("This value is required. Please try again.")
def _prompt_bool(prompt: str, default: bool) -> bool:
    """Ask a yes/no question; empty input returns ``default``.

    Accepts y/yes/1/true and n/no/0/false (case-insensitive) and re-asks on
    anything else.
    """
    answers = {
        "y": True, "yes": True, "1": True, "true": True,
        "n": False, "no": False, "0": False, "false": False,
    }
    question = f"{prompt} (y/n) [default: {'y' if default else 'n'}]: "
    while True:
        reply = input(question).strip().lower()
        if not reply:
            return default
        if reply in answers:
            return answers[reply]
        print("Please enter 'y' or 'n'.")
def _prompt_choice_from_list(
    prompt: str,
    options: List[str],
    default: Optional[str] = None,
    allow_custom: bool = True,
    custom_validator=None,
    custom_error: Optional[str] = None,
) -> Optional[str]:
    """Show a numbered list of options and read the user's choice.

    Args:
        prompt: Heading printed above the list.
        options: Selectable values. If empty, ``default`` is returned without prompting.
        default: Value returned on empty input; None or "" means "auto" (returns None).
        allow_custom: Accept values that are not in ``options``.
        custom_validator: Optional predicate applied to typed (non-numeric) values
            when ``allow_custom`` is True; a False result re-prompts.
        custom_error: Message printed when ``custom_validator`` rejects a value.

    Returns:
        The chosen option or custom value, or None for "auto".
    """
    if not options:
        return default
    print("\n" + prompt)
    for idx, option in enumerate(options, start=1):
        print(f"{idx}. {option}")
    default_display = default if default not in (None, "") else "auto"
    while True:
        choice = input(f"Choose a model (number or name) [default: {default_display}]: ").strip()
        if not choice:
            return None if default_display == "auto" else default
        if choice.lower() == "auto":
            return None
        # isdecimal (not isdigit): characters like "²" pass isdigit but make int() raise.
        if choice.isdecimal():
            idx = int(choice)
            if 1 <= idx <= len(options):
                return options[idx - 1]
            print("Invalid selection. Please choose a valid number.")
            continue
        if allow_custom:
            if custom_validator and not custom_validator(choice):
                print(custom_error or "Invalid selection. Please try again.")
                continue
            if choice not in options:
                print("Unknown model. Using as-is.")
            return choice
        if choice in options:
            # The prompt offers "number or name", so an exact option name is
            # valid even when custom values are disallowed.
            return choice
        print("Please choose a valid option.")
def _edit_formatted_prompt_via_file(formatted_prompt: str, instruction_path: str) -> str:
    """Round-trip the prompt through a file so the user can hand-edit it.

    Writes ``formatted_prompt`` to ``instruction_path`` (UTF-8), waits for the
    user to press Enter, then returns the file's contents. If writing or
    reading fails, a warning is printed and ``formatted_prompt`` is returned
    unchanged (no wait happens when the write fails).
    """
    try:
        with open(instruction_path, "w", encoding="utf-8") as handle:
            handle.write(formatted_prompt)
    except Exception as exc:
        print(f"WARNING: Failed to write {instruction_path}: {exc}")
        return formatted_prompt

    for line in (
        "\n--- Final Draft Saved ---",
        f"Saved to {instruction_path}",
        "Edit the file now. Press Enter when ready to continue.",
    ):
        print(line)
    input()  # block until the user confirms the edits are saved

    try:
        with open(instruction_path, "r", encoding="utf-8") as handle:
            edited = handle.read()
    except Exception as exc:
        print(f"WARNING: Failed to read {instruction_path}: {exc}")
        return formatted_prompt
    return edited
def _extract_caption_lyrics_from_formatted_prompt(formatted_prompt: str) -> Tuple[Optional[str], Optional[str]]:
    """Best-effort extraction of caption/lyrics from a formatted prompt string.

    Uses the last ``# Caption`` section that is followed by a ``# Lyric``
    header. The caption runs up to that header; the lyrics run from it up to
    the first chat-template marker (if any).

    Returns:
        ``(caption, lyrics)``; empty values become None, and ``(None, None)``
        is returned when no caption/lyric section is present.
    """
    # The leading greedy ``.*`` makes the search anchor on the LAST caption
    # header. A finditer over the same pattern without it can only produce one
    # match, starting at the FIRST header, because the trailing DOTALL ``(.*)``
    # consumes the rest of the string.
    match = re.search(r".*# Caption\n(.*?)\n+# Lyric\n(.*)", formatted_prompt, re.DOTALL)
    if match is None:
        return None, None
    caption = match.group(1).strip()
    lyrics = match.group(2)
    # Trim lyrics if chat-template markers appear after the user message.
    cut_markers = ("<|eot_id|>", "<|start_header_id|>", "<|assistant|>", "<|user|>", "<|system|>", "<|im_end|>", "<|im_start|>")
    positions = [pos for pos in (lyrics.find(marker) for marker in cut_markers) if pos != -1]
    cut_at = min(positions, default=len(lyrics))
    lyrics = lyrics[:cut_at].rstrip()
    return caption or None, lyrics or None
def _extract_instruction_from_formatted_prompt(formatted_prompt: str) -> Optional[str]:
    """Return the text under the first ``# Instruction`` header, or None.

    The section must be terminated by a blank line. Surrounding whitespace is
    stripped, and an empty section yields None.
    """
    found = re.search(r"# Instruction\n(.*?)\n\n", formatted_prompt, re.DOTALL)
    if found is None:
        return None
    return found.group(1).strip() or None
def _extract_cot_metadata_from_formatted_prompt(formatted_prompt: str) -> dict:
    """Parse ``key: value`` lines from the last ``<think>`` block of a prompt.

    Keys are lower-cased. A non-blank line that does not start with ``word:``
    continues the previous key's value and is joined to it with a single
    space. Blank lines are ignored. Returns {} when no ``<think>`` block exists.
    """
    think_blocks = re.findall(r"<think>\n(.*?)\n</think>", formatted_prompt, re.DOTALL)
    if not think_blocks:
        return {}
    metadata = {}
    key = None
    parts = []
    for raw_line in think_blocks[-1].splitlines():
        text = raw_line.strip()
        if not text:
            continue
        header = re.match(r"^(\w+):\s*(.*)", text)
        if header is None:
            # Continuation of a multi-line value (dropped if no key seen yet).
            if key:
                parts.append(text)
            continue
        if key:
            metadata[key] = " ".join(parts).strip()
        key = header.group(1).strip().lower()
        parts = [header.group(2).strip()]
    if key and parts:
        metadata[key] = " ".join(parts).strip()
    return metadata
def _parse_number(value: str) -> Optional[float]:
    """Return the first number embedded in ``value`` as a float.

    Returns None when no number is present or when ``value`` cannot be
    searched (for example, it is not a string).
    """
    try:
        found = re.search(r"[-+]?\d*\.?\d+", value)
        return None if found is None else float(found.group(0))
    except Exception:
        return None
def _parse_timesteps_input(value) -> Optional[List[float]]:
    """Parse a user-supplied custom timestep schedule.

    Accepts a list or tuple of numbers, or a string holding either a Python
    list/tuple literal ("[0.97, 0.5, 0]" / "(0.97, 0.5, 0)") or
    comma-separated numbers ("0.97, 0.5, 0").

    Returns:
        The timesteps as floats, or None for None/empty/unparseable input or
        sequences containing non-numbers.
    """
    def _as_floats(items) -> Optional[List[float]]:
        # Only all-numeric sequences are valid schedules.
        if all(isinstance(t, (int, float)) for t in items):
            return [float(t) for t in items]
        return None

    if value is None:
        return None
    if isinstance(value, (list, tuple)):
        return _as_floats(value)
    if not isinstance(value, str):
        return None
    raw = value.strip()
    if not raw:
        return None
    if raw.startswith(("[", "(")):
        try:
            parsed = ast.literal_eval(raw)
        except Exception:
            return None
        # "(...)" input evaluates to a tuple, so tuples must be accepted here too.
        if isinstance(parsed, (list, tuple)):
            return _as_floats(parsed)
        return None
    try:
        return [float(t.strip()) for t in raw.split(",") if t.strip()]
    except ValueError:
        return None
def _install_prompt_edit_hook(
    llm_handler: LLMHandler,
    instruction_path: str,
    preloaded_prompt: Optional[str] = None,
) -> None:
    """Intercept formatted prompt generation to allow user editing before audio tokens.

    Replaces ``llm_handler.build_formatted_prompt_with_cot`` in place with a
    wrapper. For positive prompts, the wrapper lets the user edit (or uses
    ``preloaded_prompt`` as) the prompt once per distinct prompt string and
    caches the result. For negative prompts, it rebuilds the prompt from any
    caption/lyrics edits recorded for the matching positive prompt.

    Args:
        llm_handler: Handler whose ``build_formatted_prompt_with_cot`` is patched.
        instruction_path: File the prompt is written to for interactive editing
            (only used when ``preloaded_prompt`` is None).
        preloaded_prompt: When not None, used as the edited prompt instead of
            asking the user.

    Side effects on ``llm_handler`` after an edit: may set ``_edited_caption``,
    ``_edited_lyrics``, ``_edited_instruction`` and ``_edited_metas`` from the
    edited text. A truthy ``llm_handler._skip_prompt_edit`` bypasses editing.
    """
    original = llm_handler.build_formatted_prompt_with_cot
    # Maps each unedited positive prompt to its edit result. Entries are never
    # removed, so the cache grows with each distinct prompt seen.
    cache = {}
    def wrapped(caption, lyrics, cot_text, is_negative_prompt=False, negative_prompt="NO USER INPUT"):
        prompt = original(
            caption,
            lyrics,
            cot_text,
            is_negative_prompt=is_negative_prompt,
            negative_prompt=negative_prompt,
        )
        if is_negative_prompt:
            # Rebuild the matching positive prompt to find its cache entry; if
            # the user edited caption/lyrics there, rebuild the negative prompt
            # from the edited values so both branches stay consistent.
            conditional_prompt = original(
                caption,
                lyrics,
                cot_text,
                is_negative_prompt=False,
                negative_prompt=negative_prompt,
            )
            cached = cache.get(conditional_prompt)
            if cached and (cached.get("edited_caption") or cached.get("edited_lyrics")):
                edited_caption = cached.get("edited_caption") or caption
                edited_lyrics = cached.get("edited_lyrics") or lyrics
                return original(
                    edited_caption,
                    edited_lyrics,
                    cot_text,
                    is_negative_prompt=True,
                    negative_prompt=negative_prompt,
                )
            return prompt
        # Reuse an earlier edit so the user is asked only once per prompt.
        cached = cache.get(prompt)
        if cached:
            return cached["edited_prompt"]
        if getattr(llm_handler, "_skip_prompt_edit", False):
            # Editing disabled: record the prompt as its own "edit" and pass it through.
            cache[prompt] = {
                "edited_prompt": prompt,
                "edited_caption": None,
                "edited_lyrics": None,
            }
            return prompt
        if preloaded_prompt is not None:
            # NOTE(review): the same preloaded text is returned for every
            # distinct prompt this wrapper sees — confirm that is intended.
            edited = preloaded_prompt
        else:
            edited = _edit_formatted_prompt_via_file(prompt, instruction_path)
        edited_caption, edited_lyrics = _extract_caption_lyrics_from_formatted_prompt(edited)
        if edited != prompt:
            print("INFO: Using edited draft for audio-token prompt.")
        # Expose the parsed edits on the handler for later stages to read.
        if edited_caption or edited_lyrics:
            llm_handler._edited_caption = edited_caption
            llm_handler._edited_lyrics = edited_lyrics
        edited_instruction = _extract_instruction_from_formatted_prompt(edited)
        if edited_instruction:
            llm_handler._edited_instruction = edited_instruction
        edited_metas = _extract_cot_metadata_from_formatted_prompt(edited)
        if edited_metas:
            llm_handler._edited_metas = edited_metas
        cache[prompt] = {
            "edited_prompt": edited,
            "edited_caption": edited_caption,
            "edited_lyrics": edited_lyrics,
        }
        return edited
    llm_handler.build_formatted_prompt_with_cot = wrapped
def _prompt_int(prompt: str, default: Optional[int] = None, min_value: Optional[int] = None,
                max_value: Optional[int] = None) -> Optional[int]:
    """Ask for an integer until a valid answer (or nothing) is entered.

    Blank input yields ``default``, which is displayed as "auto" when None.
    The optional ``min_value``/``max_value`` bounds are inclusive.
    """
    shown_default = "auto" if default is None else default
    question = f"{prompt} [{shown_default}]: "
    while True:
        reply = input(question).strip()
        if not reply:
            return default
        try:
            number = int(reply)
        except ValueError:
            print("Invalid input. Please enter an integer.")
            continue
        if min_value is not None and number < min_value:
            print(f"Please enter a value >= {min_value}.")
        elif max_value is not None and number > max_value:
            print(f"Please enter a value <= {max_value}.")
        else:
            return number
def _prompt_float(prompt: str, default: Optional[float] = None, min_value: Optional[float] = None,
                  max_value: Optional[float] = None) -> Optional[float]:
    """Prompt for a float, re-asking until the input is valid.

    Empty input returns ``default`` (displayed as "auto" when None). Values
    outside the inclusive ``[min_value, max_value]`` range (either bound
    optional) are rejected. Non-finite values ("nan", "inf") are rejected as
    well: ``float()`` accepts them, and NaN compares False against any bound,
    so it would otherwise pass the range checks.
    """
    import math

    default_display = "auto" if default is None else default
    while True:
        value = input(f"{prompt} [{default_display}]: ").strip()
        if not value:
            return default
        try:
            parsed = float(value)
        except ValueError:
            print("Invalid input. Please enter a number.")
            continue
        if not math.isfinite(parsed):
            print("Invalid input. Please enter a finite number.")
            continue
        if min_value is not None and parsed < min_value:
            print(f"Please enter a value >= {min_value}.")
            continue
        if max_value is not None and parsed > max_value:
            print(f"Please enter a value <= {max_value}.")
            continue
        return parsed
def _prompt_existing_file(prompt: str, default: Optional[str] = None) -> str:
    """Prompt until the user names an existing file; return its normalised path.

    Empty input falls back to ``default`` (shown in brackets). Before the
    existence check, one layer of matching surrounding quotes is removed
    (e.g. when a path is pasted quoted) and ``~`` is expanded. Without this,
    "~/song.wav" would be rejected even though ``_expand_audio_path`` expands
    ``~`` afterwards.
    """
    suffix = f" [{default}]" if default else ""
    while True:
        path = input(f"{prompt}{suffix}: ").strip()
        if not path and default:
            path = default
        if len(path) >= 2 and path[0] == path[-1] and path[0] in "\"'":
            path = path[1:-1]
        path = os.path.expanduser(path)
        if os.path.isfile(path):
            return _expand_audio_path(path)
        print("Invalid file path. Please try again.")
def _expand_audio_path(path_str: Optional[str]) -> Optional[str]:
    """Expand ``~`` and make ``path_str`` absolute, returned in POSIX form.

    Symlinks are resolved where possible (non-strict, so the path need not
    exist). If resolving fails, the plain absolute path is returned instead.
    Non-string or empty input is returned unchanged.
    """
    if isinstance(path_str, str) and path_str:
        candidate = Path(path_str).expanduser()
        try:
            return candidate.resolve(strict=False).as_posix()
        except Exception:
            return candidate.absolute().as_posix()
    return path_str
def _parse_bool(value: str) -> bool:
    """Interpret ``value`` as a flag: "true", "1", "yes" or "y" (any case) are True."""
    normalized = str(value).lower()
    return normalized in ("true", "1", "yes", "y")
def _resolve_device(device: str) -> str:
    """Map "auto" to the first available backend: xpu, cuda, mps, then cpu.

    Any value other than "auto" is returned unchanged.
    """
    if device != "auto":
        return device
    xpu_module = getattr(torch, "xpu", None)
    if xpu_module is not None and xpu_module.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
def _default_instruction_for_task(task_type: str, tracks: Optional[List[str]] = None) -> str:
    """Build the default DiT instruction text for ``task_type``.

    "lego" and "extract" fill ``TRACK_NAME`` with the first track upper-cased
    (defaulting to "guitar" and "vocals" respectively). "complete" fills
    ``TRACK_CLASSES`` with the comma-joined tracks (default
    "drums, bass, guitar"). Any other task gets ``DEFAULT_DIT_INSTRUCTION``.
    """
    single_track_fallbacks = {"lego": "guitar", "extract": "vocals"}
    if task_type in single_track_fallbacks:
        track_name = tracks[0] if tracks else single_track_fallbacks[task_type]
        return TASK_INSTRUCTIONS[task_type].format(TRACK_NAME=track_name.upper())
    if task_type == "complete":
        track_classes = ", ".join(tracks) if tracks else "drums, bass, guitar"
        return TASK_INSTRUCTIONS["complete"].format(TRACK_CLASSES=track_classes)
    return DEFAULT_DIT_INSTRUCTION
def _apply_optional_defaults(args, params_defaults: GenerationParams, config_defaults: GenerationConfig) -> None:
    """Fill every optional CLI argument that is unset (None) on ``args``, in place.

    Most fallbacks are copied from same-named fields of ``params_defaults`` or
    ``config_defaults``. A few CLI-only options use fixed values. Attributes
    that already hold a non-None value are left untouched.
    """
    param_fields = (
        "duration", "bpm", "keyscale", "timesignature", "vocal_language",
        "inference_steps", "seed", "guidance_scale", "use_adg",
        "cfg_interval_start", "cfg_interval_end", "infer_method",
        "repainting_start", "repainting_end", "audio_cover_strength", "thinking",
        "lm_temperature", "lm_cfg_scale", "lm_top_k", "lm_top_p",
        "lm_negative_prompt", "use_cot_metas", "use_cot_caption",
        "use_cot_lyrics", "use_cot_language", "use_constrained_decoding",
    )
    config_fields = (
        "batch_size", "allow_lm_batch", "use_random_seed", "seeds",
        "lm_batch_chunk_size", "constrained_decoding_debug", "audio_format",
    )
    fallbacks = {name: getattr(params_defaults, name) for name in param_fields}
    fallbacks.update({name: getattr(config_defaults, name) for name in config_fields})
    # CLI-level fallbacks that are not taken from the dataclass defaults.
    fallbacks.update(shift=3.0, timesteps=None, sample_mode=False, sample_query="", use_format=False)
    for name, fallback in fallbacks.items():
        if getattr(args, name, None) is None:
            setattr(args, name, fallback)
def _summarize_lyrics(lyrics: Optional[str]) -> str:
    """Describe ``lyrics`` in one short line for the parameter summary.

    Returns "none" for empty/blank input and "file: <basename>" when the text
    is an existing file path. Text of at most 60 characters is returned with
    newlines turned into spaces, and longer text as "text (<n> chars)".
    Truthy non-string values give "provided".
    """
    if not lyrics:
        return "none"
    if not isinstance(lyrics, str):
        return "provided"
    text = lyrics.strip()
    if not text:
        return "none"
    if os.path.isfile(text):
        return f"file: {os.path.basename(text)}"
    return text.replace("\n", " ") if len(text) <= 60 else f"text ({len(text)} chars)"
def _print_final_parameters(
    args,
    params: GenerationParams,
    config: GenerationConfig,
    params_defaults: GenerationParams,
    config_defaults: GenerationConfig,
    compact: bool,
    resolved_device: Optional[str] = None,
) -> None:
    """Print the parameters that will be used for generation.

    Non-compact mode dumps every attribute of ``args``, ``params`` and
    ``config``, sorted by name. Compact mode prints a short summary; bpm,
    keyscale and timesignature appear only when they differ from both None
    and the value in ``params_defaults``. ``config_defaults`` is accepted but
    not read here.
    """
    if not compact:
        dumps = (
            ("Args", args, "------------------------------"),
            ("GenerationParams", params, "-------------------------------------------"),
            ("GenerationConfig", config, "-------------------------------------------\n"),
        )
        for title, source, closing_rule in dumps:
            print(f"\n--- Final Parameters ({title}) ---")
            for attr_name in sorted(vars(source)):
                print(f"{attr_name}: {getattr(source, attr_name)}")
            print(closing_rule)
        return

    # Show "requested -> resolved" when device auto-detection changed the value.
    if resolved_device and resolved_device != args.device:
        device_display = f"{args.device} -> {resolved_device}"
    else:
        device_display = args.device

    print("\n--- Final Parameters (Summary) ---")
    print(f"task_type: {params.task_type}")
    print(f"caption: {params.caption or 'none'}")
    print(f"lyrics: {_summarize_lyrics(params.lyrics)}")
    print(f"duration: {params.duration}s")
    print(f"outputs: {config.batch_size}")
    for meta_name in ("bpm", "keyscale", "timesignature"):
        meta_value = getattr(params, meta_name)
        if meta_value not in (None, getattr(params_defaults, meta_name)):
            print(f"{meta_name}: {meta_value}")
    print(f"instrumental: {params.instrumental}")
    print(f"thinking: {params.thinking}")
    print(f"lm_model: {args.lm_model_path or 'auto'}")
    print(f"dit_model: {args.config_path or 'auto'}")
    print(f"backend: {args.backend}")
    print(f"device: {device_display}")
    print(f"audio_format: {config.audio_format}")
    print(f"save_dir: {args.save_dir}")
    if config.seeds:
        print(f"seeds: {config.seeds}")
    else:
        print(f"seed: {params.seed} (random={config.use_random_seed})")
    print("-------------------------------\n")
def _build_meta_dict(params: GenerationParams) -> Optional[dict]:
    """Collect the musical metadata set on ``params`` into a dict.

    bpm and duration are included whenever they are not None. timesignature
    and keyscale are included only when truthy. Returns None if nothing was
    collected.
    """
    entries = (
        ("bpm", params.bpm, params.bpm is not None),
        ("timesignature", params.timesignature, bool(params.timesignature)),
        ("keyscale", params.keyscale, bool(params.keyscale)),
        ("duration", params.duration, params.duration is not None),
    )
    meta = {key: value for key, value, include in entries if include}
    return meta if meta else None
def _print_dit_prompt(dit_handler: "AceStepHandler", params: GenerationParams) -> None:
    """Print the caption-branch and lyrics-branch inputs built for ``params``.

    The inputs come from ``dit_handler.build_dit_inputs``. Missing caption or
    lyrics are passed as "", and a missing vocal language as "unknown".
    """
    metas = _build_meta_dict(params)
    caption_branch, lyrics_branch = dit_handler.build_dit_inputs(
        task=params.task_type,
        instruction=params.instruction,
        caption=params.caption or "",
        lyrics=params.lyrics or "",
        metas=metas,
        vocal_language=params.vocal_language or "unknown",
    )
    for heading, body in (("Caption", caption_branch), ("Lyrics", lyrics_branch)):
        print(f"\n--- Final DiT Prompt ({heading} Branch) ---")
        print(body)
    print("----------------------------------------\n")
| def run_wizard(args, configure_only: bool = False, default_config_path: Optional[str] = None, | |
| params_defaults: Optional[GenerationParams] = None, | |
| config_defaults: Optional[GenerationConfig] = None): | |
| """ | |
| Runs an interactive wizard to set generation parameters. | |
| """ | |
| print("Welcome to the ACE-Step Music Generation Wizard!") | |
| print("This will guide you through creating your music.") | |
| print("Press Ctrl+C at any time to exit.") | |
| print("Note: Required models will be auto-downloaded if missing.") | |
| print("-" * 30) | |
| try: | |
| # Task selection | |
| print("\n--- Task Type ---") | |
| print("1. text2music - generate music from text/lyrics.") | |
| print("2. cover - transform existing audio into a new style.") | |
| print("3. repaint - regenerate a specific time segment of audio.") | |
| print("4. lego - generate a specific instrument track in context.") | |
| print("5. extract - isolate a specific instrument track from a mix.") | |
| print("6. complete - complete/extend partial tracks with new instruments.") | |
| task_map = { | |
| "1": "text2music", | |
| "2": "cover", | |
| "3": "repaint", | |
| "4": "lego", | |
| "5": "extract", | |
| "6": "complete", | |
| } | |
| current_task = args.task_type or "text2music" | |
| task_default = next((k for k, v in task_map.items() if v == current_task), "1") | |
| task_choice = input(f"Choose a task (1-6) [default: {task_default}]: ").strip() | |
| if not task_choice: | |
| task_choice = task_default | |
| args.task_type = task_map.get(task_choice, "text2music") | |
| if args.task_type in {"lego", "extract", "complete"}: | |
| print("Note: This task requires a base DiT model (acestep-v15-base). It will be auto-downloaded if missing.") | |
| # Model selection (DiT) | |
| dit_handler = AceStepHandler() | |
| available_dit_models = dit_handler.get_available_acestep_v15_models() | |
| base_only = args.task_type in {"lego", "extract", "complete"} | |
| if base_only and available_dit_models: | |
| available_dit_models = [m for m in available_dit_models if "base" in m.lower()] | |
| if base_only and args.config_path and "base" not in str(args.config_path).lower(): | |
| args.config_path = None | |
| if base_only: | |
| if available_dit_models: | |
| if args.config_path in available_dit_models: | |
| selected = args.config_path | |
| else: | |
| selected = available_dit_models[0] | |
| args.config_path = selected | |
| print(f"\nNote: This task requires a base model. Using: {selected}") | |
| else: | |
| print("\nNote: This task requires a base model (e.g., 'acestep-v15-base'). It will be auto-downloaded if missing.") | |
| elif available_dit_models: | |
| selected = _prompt_choice_from_list( | |
| "--- Available DiT Models ---", | |
| available_dit_models, | |
| default=args.config_path, | |
| allow_custom=True, | |
| ) | |
| if selected is not None: | |
| args.config_path = selected | |
| else: | |
| print("\nNote: No local DiT models found. The main model will be auto-downloaded during initialization.") | |
| # Model selection (LM) | |
| llm_handler = LLMHandler() | |
| available_lm_models = llm_handler.get_available_5hz_lm_models() | |
| if available_lm_models: | |
| selected_lm = _prompt_choice_from_list( | |
| "--- Available LM Models ---", | |
| available_lm_models, | |
| default=args.lm_model_path, | |
| allow_custom=True, | |
| ) | |
| if selected_lm is not None: | |
| args.lm_model_path = selected_lm | |
| else: | |
| print("\nNote: No local LM models found. If LM features are enabled, a default LM will be auto-downloaded.") | |
| # Task-specific inputs | |
| if args.task_type in {"cover", "repaint", "lego", "extract", "complete"}: | |
| args.src_audio = _prompt_existing_file("Enter path to source audio file", default=args.src_audio) | |
| if args.task_type == "repaint": | |
| args.repainting_start = _prompt_float( | |
| "Repaint start time in seconds", args.repainting_start | |
| ) | |
| args.repainting_end = _prompt_float( | |
| "Repaint end time in seconds", args.repainting_end | |
| ) | |
| if args.task_type in {"lego", "extract"}: | |
| print("\nAvailable tracks:") | |
| print(", ".join(TRACK_CHOICES)) | |
| track_default = args.lego_track if args.task_type == "lego" else args.extract_track | |
| track = _prompt_with_default("Choose a track", track_default, required=True) | |
| if track not in TRACK_CHOICES: | |
| print("Unknown track. Using as-is.") | |
| if args.task_type == "lego": | |
| args.lego_track = track | |
| else: | |
| args.extract_track = track | |
| if not args.instruction or args.instruction == DEFAULT_DIT_INSTRUCTION: | |
| args.instruction = _default_instruction_for_task(args.task_type, [track]) | |
| args.instruction = _prompt_with_default("Instruction", args.instruction, required=True) | |
| if args.task_type == "complete": | |
| print("\nAvailable tracks:") | |
| print(", ".join(TRACK_CHOICES)) | |
| tracks_raw = _prompt_with_default("Choose tracks (comma-separated)", args.complete_tracks, required=True) | |
| tracks = [t.strip() for t in tracks_raw.split(",") if t.strip()] | |
| args.complete_tracks = ",".join(tracks) | |
| if not args.instruction or args.instruction == DEFAULT_DIT_INSTRUCTION: | |
| args.instruction = _default_instruction_for_task(args.task_type, tracks) | |
| args.instruction = _prompt_with_default("Instruction", args.instruction, required=True) | |
| if args.task_type in {"cover", "repaint", "lego", "complete"}: | |
| args.caption = _prompt_with_default( | |
| "Enter a music description (e.g., 'upbeat electronic dance music')", | |
| args.caption, | |
| required=True, | |
| ) | |
| elif args.task_type == "text2music": | |
| args.sample_mode = _prompt_bool("Use Simple Mode (auto-generate caption/lyrics via LM)", args.sample_mode) | |
| if args.sample_mode: | |
| args.sample_query = _prompt_with_default( | |
| "Describe the music you want (for auto-generation)", | |
| args.sample_query, | |
| required=False, | |
| ) | |
| if not args.sample_mode: | |
| caption = _prompt_with_default( | |
| "Enter a music description (optional if you provide lyrics)", | |
| args.caption, | |
| required=False, | |
| ) | |
| if caption: | |
| args.caption = caption | |
| # Lyrics | |
| if args.task_type in {"text2music", "cover", "repaint", "lego", "complete"} and not args.sample_mode: | |
| print("\n--- Lyrics Options ---") | |
| print("1. Instrumental (no lyrics).") | |
| print("2. Generate lyrics automatically.") | |
| print("3. Provide path to a .txt file.") | |
| print("4. Paste lyrics directly.") | |
| if args.instrumental or args.lyrics == "[Instrumental]": | |
| default_choice = "1" | |
| elif args.use_cot_lyrics: | |
| default_choice = "2" | |
| elif args.lyrics and isinstance(args.lyrics, str) and os.path.isfile(args.lyrics): | |
| default_choice = "3" | |
| elif args.lyrics: | |
| default_choice = "4" | |
| else: | |
| default_choice = "1" | |
| choice = input(f"Your choice (1-4) [default: {default_choice}]: ").strip() | |
| if not choice: | |
| choice = default_choice | |
| if choice == "1": # Instrumental | |
| args.instrumental = True | |
| args.lyrics = "[Instrumental]" | |
| args.use_cot_lyrics = False | |
| print("Instrumental music will be generated.") | |
| elif choice == "2": # Generate lyrics automatically | |
| args.use_cot_lyrics = True | |
| args.lyrics = "" | |
| args.instrumental = False | |
| print("Lyrics will be generated automatically.") | |
| elif choice == "3": | |
| args.instrumental = False | |
| args.use_cot_lyrics = False | |
| default_lyrics_path = args.lyrics if isinstance(args.lyrics, str) and os.path.isfile(args.lyrics) else None | |
| while True: | |
| lyrics_path = _prompt_existing_file("Please enter the path to your .txt lyrics file", default_lyrics_path) | |
| if lyrics_path.endswith('.txt'): | |
| args.lyrics = lyrics_path | |
| print(f"Lyrics will be loaded from: {lyrics_path}") | |
| break | |
| print("Invalid file path or not a .txt file. Please try again.") | |
| elif choice == "4": | |
| args.instrumental = False | |
| args.use_cot_lyrics = False | |
| default_lyrics = args.lyrics if isinstance(args.lyrics, str) and args.lyrics and not os.path.isfile(args.lyrics) else None | |
| args.lyrics = _prompt_with_default("Paste lyrics (single line or use \\n)", default_lyrics, required=True) | |
| if not args.instrumental: | |
| lang = _prompt_with_default( | |
| "Vocal language (e.g., 'en', 'zh', 'unknown')", | |
| args.vocal_language, | |
| required=False | |
| ).lower() | |
| if lang: | |
| args.vocal_language = lang | |
| if args.use_cot_lyrics: | |
| if not args.caption: | |
| args.caption = _prompt_non_empty("Enter a music description for lyric generation: ") | |
| if not args.thinking: | |
| print("INFO: Automatic lyric generation requires the LM handler. Enabling LM 'thinking'.") | |
| args.thinking = True | |
| args.batch_size = _prompt_int( | |
| "Number of outputs (audio clips) to generate", | |
| args.batch_size if args.batch_size is not None else 2, | |
| min_value=1, | |
| ) | |
| advanced = input("\nConfigure advanced parameters? (y/n) [default: n]: ").lower() | |
| if advanced == 'y': | |
| if args.task_type == "text2music" and not args.sample_mode: | |
| args.use_format = _prompt_bool("Use format_sample to enhance caption/lyrics", args.use_format) | |
| print("\n--- Optional Metadata ---") | |
| args.duration = _prompt_float("Duration in seconds (10-600)", args.duration, min_value=10, max_value=600) | |
| args.bpm = _prompt_int("BPM (30-300, empty for auto)", args.bpm, min_value=30, max_value=300) | |
| args.keyscale = _prompt_with_default("Keyscale (e.g., 'C Major', empty for auto)", args.keyscale) | |
| args.timesignature = _prompt_with_default("Time signature (e.g., '4/4', empty for auto)", args.timesignature) | |
| args.vocal_language = _prompt_with_default("Vocal language (e.g., 'en', 'zh', 'unknown')", args.vocal_language) | |
| print("\n--- Advanced DiT Settings ---") | |
| args.seed = _prompt_int("Random seed (-1 for random)", args.seed) | |
| args.inference_steps = _prompt_int("Inference steps", args.inference_steps, min_value=1) | |
| if args.config_path and 'base' in args.config_path: | |
| args.guidance_scale = _prompt_float("Guidance scale (for base models)", args.guidance_scale) | |
| args.use_adg = _prompt_bool("Enable Adaptive Dual Guidance (ADG)", args.use_adg) | |
| args.cfg_interval_start = _prompt_float("CFG interval start (0.0-1.0)", args.cfg_interval_start, 0.0, 1.0) | |
| args.cfg_interval_end = _prompt_float("CFG interval end (0.0-1.0)", args.cfg_interval_end, 0.0, 1.0) | |
| args.shift = _prompt_float("Timestep shift (1.0-5.0)", args.shift, 1.0, 5.0) | |
| args.infer_method = _prompt_with_default("Inference method (ode/sde)", args.infer_method) | |
| timesteps_input = _prompt_with_default( | |
| "Custom timesteps list (e.g., [0.97, 0.5, 0])", | |
| args.timesteps, | |
| required=False, | |
| ) | |
| if timesteps_input: | |
| args.timesteps = timesteps_input | |
| if args.task_type == "cover": | |
| args.audio_cover_strength = _prompt_float( | |
| "Audio cover strength (0.0-1.0)", args.audio_cover_strength, 0.0, 1.0 | |
| ) | |
| print("\n--- Advanced LM Settings ---") | |
| args.thinking = _prompt_bool("Enable LM 'thinking'", args.thinking) | |
| args.lm_temperature = _prompt_float("LM temperature (0.0-2.0)", args.lm_temperature, 0.0, 2.0) | |
| args.lm_cfg_scale = _prompt_float("LM CFG scale", args.lm_cfg_scale) | |
| args.lm_top_k = _prompt_int("LM top-k (0 disables)", args.lm_top_k, min_value=0) | |
| args.lm_top_p = _prompt_float("LM top-p (0.0-1.0)", args.lm_top_p, 0.0, 1.0) | |
| args.lm_negative_prompt = _prompt_with_default("LM negative prompt", args.lm_negative_prompt) | |
| args.use_cot_metas = _prompt_bool("Use CoT for metadata", args.use_cot_metas) | |
| args.use_cot_caption = _prompt_bool("Use CoT for caption refinement", args.use_cot_caption) | |
| args.use_cot_lyrics = _prompt_bool("Use CoT for lyrics generation", args.use_cot_lyrics) | |
| args.use_cot_language = _prompt_bool("Use CoT for language detection", args.use_cot_language) | |
| args.use_constrained_decoding = _prompt_bool("Use constrained decoding", args.use_constrained_decoding) | |
| print("\n--- Output Settings ---") | |
| args.save_dir = _prompt_with_default("Save directory", args.save_dir) | |
| args.audio_format = _prompt_with_default("Audio format (mp3/wav/flac)", args.audio_format) | |
| # Batch size already captured above. | |
| args.use_random_seed = _prompt_bool("Use random seed per batch", args.use_random_seed) | |
| seeds_input = _prompt_with_default( | |
| "Custom seeds (comma/space separated, leave empty for random)", | |
| "", | |
| required=False, | |
| ) | |
| if seeds_input: | |
| seeds = [s for s in seeds_input.replace(",", " ").split() if s.strip()] | |
| try: | |
| args.seeds = [int(s) for s in seeds] | |
| except ValueError: | |
| print("Invalid seeds input. Ignoring custom seeds.") | |
| args.allow_lm_batch = _prompt_bool("Allow LM batch processing", args.allow_lm_batch) | |
| args.lm_batch_chunk_size = _prompt_int("LM batch chunk size", args.lm_batch_chunk_size, min_value=1) | |
| args.constrained_decoding_debug = _prompt_bool("Constrained decoding debug", args.constrained_decoding_debug) | |
| else: | |
| if params_defaults and config_defaults: | |
| _apply_optional_defaults(args, params_defaults, config_defaults) | |
| # Ensure LM thinking is enabled when lyric generation is requested. | |
| if args.use_cot_lyrics and not args.thinking: | |
| print("INFO: Automatic lyric generation requires the LM handler. Enabling LM 'thinking'.") | |
| args.thinking = True | |
| print("\n--- Summary ---") | |
| print(f"Task: {args.task_type}") | |
| if args.caption: | |
| print(f"Description: {args.caption}") | |
| if args.task_type in {"lego", "extract", "complete"}: | |
| print(f"Instruction: {args.instruction}") | |
| if args.src_audio: | |
| print(f"Source audio: {args.src_audio}") | |
| print(f"Duration: {args.duration}s") | |
| print(f"Outputs: {args.batch_size}") | |
| if args.instrumental: | |
| print("Lyrics: Instrumental") | |
| elif args.use_cot_lyrics: | |
| print(f"Lyrics: Auto-generated ({args.vocal_language})") | |
| elif args.lyrics and os.path.isfile(args.lyrics): | |
| print(f"Lyrics: Provided from file ({args.lyrics})") | |
| elif args.lyrics: | |
| print(f"Lyrics: Provided as text") | |
| print("-" * 30) | |
| if not configure_only: | |
| confirm = input("Start generation with these settings? (y/n) [default: y]: ").lower() | |
| if confirm == 'n': | |
| print("Generation cancelled.") | |
| sys.exit(0) | |
| default_filename = default_config_path or "config.toml" | |
| config_filename = input(f"\nEnter filename to save configuration [{default_filename}]: ") | |
| if not config_filename: | |
| config_filename = default_filename | |
| if not config_filename.endswith(".toml"): | |
| config_filename += ".toml" | |
| try: | |
| config_to_save = { | |
| k: v for k, v in vars(args).items() | |
| if k not in ['config'] and not k.startswith('_') | |
| } | |
| with open(config_filename, 'w') as f: | |
| toml.dump(config_to_save, f) | |
| print(f"Configuration saved to {config_filename}") | |
| print(f"You can reuse it next time with: python cli.py -c {config_filename}") | |
| except Exception as e: | |
| print(f"Error saving configuration: {e}. Please try again.") | |
| except (KeyboardInterrupt, EOFError): | |
| print("\nWizard cancelled. Exiting.") | |
| sys.exit(0) | |
| return args, not configure_only | |
| def main(): | |
| """ | |
| Main function to run ACE-Step music generation from the command line. | |
| """ | |
| gpu_config = get_gpu_config() | |
| set_global_gpu_config(gpu_config) | |
| mps_available = is_mps_platform() | |
| # Mac (Apple Silicon) uses unified memory — offloading provides no benefit | |
| auto_offload = (not mps_available) and gpu_config.gpu_memory_gb > 0 and gpu_config.gpu_memory_gb < 16 | |
| print(f"\n{'='*60}") | |
| print("GPU Configuration Detected:") | |
| print(f"{'='*60}") | |
| print(f" GPU Memory: {gpu_config.gpu_memory_gb:.2f} GiB") | |
| print(f" Configuration Tier: {gpu_config.tier}") | |
| print(f" Max Duration (with LM): {gpu_config.max_duration_with_lm}s ({gpu_config.max_duration_with_lm // 60} min)") | |
| print(f" Max Duration (without LM): {gpu_config.max_duration_without_lm}s ({gpu_config.max_duration_without_lm // 60} min)") | |
| print(f" Max Batch Size (with LM): {gpu_config.max_batch_size_with_lm}") | |
| print(f" Max Batch Size (without LM): {gpu_config.max_batch_size_without_lm}") | |
| print(f" Default LM Init: {gpu_config.init_lm_default}") | |
| print(f" Available LM Models: {gpu_config.available_lm_models or 'None'}") | |
| print(f"{'='*60}\n") | |
| if auto_offload: | |
| print("Auto-enabling CPU offload (GPU < 16GB)") | |
| elif gpu_config.gpu_memory_gb > 0: | |
| print("CPU offload disabled by default (GPU >= 16GB)") | |
| elif mps_available: | |
| print("MPS detected, running on Apple GPU") | |
| else: | |
| print("No GPU detected, running on CPU") | |
| params_defaults = GenerationParams() | |
| config_defaults = GenerationConfig() | |
| parser = argparse.ArgumentParser( | |
| description="ACE-Step 1.5: Music generation (wizard/config only).", | |
| formatter_class=argparse.ArgumentDefaultsHelpFormatter | |
| ) | |
| parser.add_argument("-c", "--config", type=str, help="Path to a TOML configuration file to load.") | |
| parser.add_argument("--configure", action="store_true", help="Run wizard to save configuration without generating.") | |
| parser.add_argument( | |
| "--backend", | |
| type=str, | |
| default=None, | |
| choices=["vllm", "pt", "mlx"], | |
| help="5Hz LM backend. Auto-detected if not specified: 'mlx' on Apple Silicon, 'vllm' on CUDA, 'pt' otherwise.", | |
| ) | |
| parser.add_argument( | |
| "--log-level", | |
| type=str, | |
| default="INFO", | |
| help="Logging level for internal modules (TRACE/DEBUG/INFO/WARNING/ERROR/CRITICAL).", | |
| ) | |
| cli_args = parser.parse_args() | |
| _configure_logging(level=cli_args.log_level) | |
| default_batch_size = 1 if not cli_args.config else config_defaults.batch_size | |
| # Auto-detect MLX on Apple Silicon, fall back to vllm | |
| if mps_available: | |
| try: | |
| import mlx.core # noqa: F401 | |
| default_backend = "mlx" | |
| print("Apple Silicon detected with MLX available. Using MLX backend.") | |
| except ImportError: | |
| default_backend = "vllm" | |
| else: | |
| default_backend = "vllm" | |
| defaults = { | |
| "project_root": _get_project_root(), | |
| "config_path": None, | |
| "checkpoint_dir": os.path.join(_get_project_root(), "checkpoints"), | |
| "lm_model_path": None, | |
| "backend": default_backend, | |
| "device": "auto", | |
| "use_flash_attention": None, | |
| "offload_to_cpu": auto_offload, | |
| "offload_dit_to_cpu": False, | |
| "save_dir": "output", | |
| "audio_format": config_defaults.audio_format, | |
| "caption": "", | |
| "prompt": "", | |
| "lyrics": None, | |
| "duration": params_defaults.duration, | |
| "instrumental": False, | |
| "bpm": params_defaults.bpm, | |
| "keyscale": params_defaults.keyscale, | |
| "timesignature": params_defaults.timesignature, | |
| "vocal_language": params_defaults.vocal_language, | |
| "task_type": params_defaults.task_type, | |
| "instruction": params_defaults.instruction, | |
| "reference_audio": params_defaults.reference_audio, | |
| "src_audio": params_defaults.src_audio, | |
| "repainting_start": params_defaults.repainting_start, | |
| "repainting_end": params_defaults.repainting_end, | |
| "audio_cover_strength": params_defaults.audio_cover_strength, | |
| "lego_track": "", | |
| "extract_track": "", | |
| "complete_tracks": "", | |
| "sample_mode": False, | |
| "sample_query": "", | |
| "use_format": False, | |
| "inference_steps": params_defaults.inference_steps, | |
| "seed": params_defaults.seed, | |
| "guidance_scale": params_defaults.guidance_scale, | |
| "use_adg": params_defaults.use_adg, | |
| "shift": 3.0, | |
| "infer_method": params_defaults.infer_method, | |
| "timesteps": None, | |
| "thinking": gpu_config.init_lm_default, | |
| "lm_temperature": params_defaults.lm_temperature, | |
| "lm_cfg_scale": params_defaults.lm_cfg_scale, | |
| "lm_top_k": params_defaults.lm_top_k, | |
| "lm_top_p": params_defaults.lm_top_p, | |
| "use_cot_metas": params_defaults.use_cot_metas, | |
| "use_cot_caption": params_defaults.use_cot_caption, | |
| "use_cot_lyrics": params_defaults.use_cot_lyrics, | |
| "use_cot_language": params_defaults.use_cot_language, | |
| "use_constrained_decoding": params_defaults.use_constrained_decoding, | |
| "batch_size": default_batch_size, | |
| "seeds": None, | |
| "use_random_seed": config_defaults.use_random_seed, | |
| "allow_lm_batch": config_defaults.allow_lm_batch, | |
| "lm_batch_chunk_size": config_defaults.lm_batch_chunk_size, | |
| "constrained_decoding_debug": config_defaults.constrained_decoding_debug, | |
| "audio_codes": "", | |
| "cfg_interval_start": params_defaults.cfg_interval_start, | |
| "cfg_interval_end": params_defaults.cfg_interval_end, | |
| "lm_negative_prompt": params_defaults.lm_negative_prompt, | |
| "log_level": cli_args.log_level, | |
| } | |
| args = argparse.Namespace(**defaults) | |
| args.config = None | |
| if cli_args.config: | |
| if not os.path.exists(cli_args.config): | |
| parser.error(f"Config file not found: {cli_args.config}") | |
| try: | |
| with open(cli_args.config, 'r') as f: | |
| config_from_file = toml.load(f) | |
| print(f"Configuration loaded from {cli_args.config}") | |
| except Exception as e: | |
| parser.error(f"Error loading TOML config file {cli_args.config}: {e}") | |
| for key, value in config_from_file.items(): | |
| setattr(args, key, value) | |
| args.config = cli_args.config | |
| # CLI --backend overrides config file and auto-detection | |
| if cli_args.backend is not None: | |
| args.backend = cli_args.backend | |
| if cli_args.configure: | |
| args, _ = run_wizard( | |
| args, | |
| configure_only=True, | |
| default_config_path=cli_args.config, | |
| params_defaults=params_defaults, | |
| config_defaults=config_defaults, | |
| ) | |
| print("Configuration complete. Exiting without generation.") | |
| sys.exit(0) | |
| if not cli_args.config: | |
| args, should_generate = run_wizard( | |
| args, | |
| configure_only=False, | |
| default_config_path=None, | |
| params_defaults=params_defaults, | |
| config_defaults=config_defaults, | |
| ) | |
| if not should_generate: | |
| print("Configuration complete. Exiting without generation.") | |
| sys.exit(0) | |
| # --- Post-parsing Setup --- | |
| if args.use_cot_lyrics and not args.thinking: | |
| print("INFO: Automatic lyric generation requires the LM handler. Forcing --thinking=True.") | |
| args.thinking = True | |
| if not args.project_root: | |
| args.project_root = _get_project_root() | |
| else: | |
| args.project_root = os.path.abspath(os.path.expanduser(str(args.project_root))) | |
| if args.checkpoint_dir: | |
| args.checkpoint_dir = os.path.expanduser(str(args.checkpoint_dir)) | |
| if not os.path.isabs(args.checkpoint_dir): | |
| args.checkpoint_dir = os.path.join(args.project_root, args.checkpoint_dir) | |
| if args.src_audio: | |
| args.src_audio = _expand_audio_path(args.src_audio) | |
| if args.reference_audio: | |
| args.reference_audio = _expand_audio_path(args.reference_audio) | |
| device = _resolve_device(args.device) | |
| # --- Argument Post-processing --- | |
| try: | |
| timesteps = _parse_timesteps_input(args.timesteps) | |
| if args.timesteps and timesteps is None: | |
| raise ValueError("Timesteps must be a list of numbers or a comma-separated string.") | |
| except ValueError as e: | |
| parser.error(f"Invalid format for timesteps. Expected a list of numbers (e.g., '[1.0, 0.5, 0.0]' or '0.97,0.5,0'). Error: {e}") | |
| if args.seeds: | |
| args.batch_size = len(args.seeds) | |
| args.use_random_seed = False | |
| args.seed = -1 | |
| if args.instrumental and not args.lyrics: | |
| args.lyrics = "[Instrumental]" | |
| elif isinstance(args.lyrics, str) and args.lyrics.strip().lower() in {"[inst]", "[instrumental]"}: | |
| args.instrumental = True | |
| # --- Task-specific validation and instruction helpers --- | |
| if args.task_type in {"cover", "repaint", "lego", "extract", "complete"}: | |
| if not args.src_audio: | |
| parser.error(f"--src_audio is required for task_type '{args.task_type}'.") | |
| if args.task_type in {"cover", "repaint", "lego", "complete"}: | |
| if not args.caption: | |
| parser.error(f"--caption is required for task_type '{args.task_type}'.") | |
| if args.task_type == "text2music": | |
| if not args.caption and not args.lyrics: | |
| if not args.sample_mode and not args.sample_query: | |
| parser.error("--caption or --lyrics is required for text2music.") | |
| if args.use_cot_lyrics and not args.caption: | |
| parser.error("--use_cot_lyrics requires --caption for lyric generation.") | |
| if args.sample_mode or args.sample_query: | |
| args.sample_mode = True | |
| else: | |
| if args.sample_mode or args.sample_query: | |
| parser.error("--sample_mode/sample_query are only supported for task_type 'text2music'.") | |
| if args.sample_mode and args.use_cot_lyrics: | |
| print("INFO: sample_mode enabled. Disabling --use_cot_lyrics.") | |
| args.use_cot_lyrics = False | |
| # Auto-select instruction based on task_type if user didn't provide a custom instruction. | |
| # Align with api_server behavior and TASK_INSTRUCTIONS defaults. | |
| if args.instruction == DEFAULT_DIT_INSTRUCTION and args.task_type in TASK_INSTRUCTIONS: | |
| if args.task_type in {"text2music", "cover", "repaint"}: | |
| args.instruction = TASK_INSTRUCTIONS[args.task_type] | |
| # Base-model-only task enforcement | |
| base_only_tasks = {"lego", "extract", "complete"} | |
| if args.task_type in base_only_tasks and args.config_path: | |
| if "base" not in str(args.config_path).lower(): | |
| parser.error(f"task_type '{args.task_type}' requires a base model config (e.g., 'acestep-v15-base').") | |
| if args.task_type == "repaint": | |
| if args.repainting_end != -1 and args.repainting_end <= args.repainting_start: | |
| parser.error("--repainting_end must be greater than --repainting_start (or -1).") | |
| if args.task_type in {"lego", "extract", "complete"}: | |
| has_custom_instruction = bool(args.instruction and args.instruction.strip() and args.instruction.strip() != params_defaults.instruction) | |
| if not has_custom_instruction: | |
| if args.task_type == "lego": | |
| if not args.lego_track: | |
| parser.error("--instruction or --lego_track is required for lego task.") | |
| args.instruction = _default_instruction_for_task("lego", [args.lego_track.strip()]) | |
| elif args.task_type == "extract": | |
| if not args.extract_track: | |
| parser.error("--instruction or --extract_track is required for extract task.") | |
| args.instruction = _default_instruction_for_task("extract", [args.extract_track.strip()]) | |
| elif args.task_type == "complete": | |
| if not args.complete_tracks: | |
| parser.error("--instruction or --complete_tracks is required for complete task.") | |
| tracks = [t.strip() for t in args.complete_tracks.split(",") if t.strip()] | |
| if not tracks: | |
| parser.error("--complete_tracks must contain at least one track.") | |
| args.instruction = _default_instruction_for_task("complete", tracks) | |
| # Handle lyrics argument | |
| lyrics_arg = args.lyrics | |
| if isinstance(lyrics_arg, str) and lyrics_arg: | |
| lyrics_arg = os.path.expanduser(lyrics_arg) | |
| if not os.path.isabs(lyrics_arg): | |
| # Resolve relative lyrics path against config file location first, then project_root. | |
| resolved = None | |
| if args.config: | |
| config_dir = os.path.dirname(os.path.abspath(args.config)) | |
| candidate = os.path.join(config_dir, lyrics_arg) | |
| if os.path.isfile(candidate): | |
| resolved = candidate | |
| if resolved is None and args.project_root: | |
| candidate = os.path.join(os.path.abspath(args.project_root), lyrics_arg) | |
| if os.path.isfile(candidate): | |
| resolved = candidate | |
| if resolved is not None: | |
| lyrics_arg = resolved | |
| if lyrics_arg is not None: | |
| if lyrics_arg == "generate": | |
| args.use_cot_lyrics = True | |
| args.lyrics = "" | |
| print("Lyrics generation enabled.") | |
| elif os.path.isfile(lyrics_arg): | |
| print(f"INFO: Attempting to load lyrics from file: {lyrics_arg}") | |
| try: | |
| with open(lyrics_arg, 'r', encoding='utf-8') as f: | |
| args.lyrics = f.read() | |
| print(f"Lyrics loaded from file: {lyrics_arg}") | |
| except Exception as e: | |
| parser.error(f"Could not read lyrics file {lyrics_arg}. Error: {e}") | |
| # else: lyrics is a string, use as is. | |
| # --- Handler Initialization --- | |
| if args.backend == "pyTorch": | |
| args.backend = "pt" | |
| if args.backend not in {"vllm", "pt", "mlx"}: | |
| args.backend = "vllm" | |
| print("Initializing ACE-Step handlers...") | |
| dit_handler = AceStepHandler() | |
| llm_handler = LLMHandler() | |
| base_only_tasks = {"lego", "extract", "complete"} | |
| skip_lm_tasks = {"cover", "repaint"} | |
| requires_lm = ( | |
| args.task_type not in skip_lm_tasks and ( | |
| args.thinking | |
| or args.sample_mode | |
| or bool(args.sample_query and str(args.sample_query).strip()) | |
| or args.use_format | |
| or args.use_cot_metas | |
| or args.use_cot_caption | |
| or args.use_cot_lyrics | |
| or args.use_cot_language | |
| ) | |
| ) | |
| if args.config_path is None: | |
| available_models = dit_handler.get_available_acestep_v15_models() | |
| if args.task_type in base_only_tasks and available_models: | |
| available_models = [m for m in available_models if "base" in m.lower()] | |
| if not available_models: | |
| print("No DiT models found. Downloading main model (acestep-v15-turbo + core components)...") | |
| from acestep.model_downloader import ensure_main_model, get_checkpoints_dir | |
| checkpoints_dir = get_checkpoints_dir() | |
| success, msg = ensure_main_model(checkpoints_dir) | |
| print(msg) | |
| if not success: | |
| parser.error(f"Failed to download main model: {msg}") | |
| available_models = dit_handler.get_available_acestep_v15_models() | |
| if args.task_type in base_only_tasks and available_models: | |
| available_models = [m for m in available_models if "base" in m.lower()] | |
| if args.task_type in base_only_tasks and not available_models: | |
| print("Base-only task selected. Downloading base DiT model (acestep-v15-base)...") | |
| from acestep.model_downloader import ensure_dit_model, get_checkpoints_dir | |
| checkpoints_dir = get_checkpoints_dir() | |
| success, msg = ensure_dit_model("acestep-v15-base", checkpoints_dir) | |
| print(msg) | |
| if not success: | |
| parser.error(f"Failed to download base DiT model: {msg}") | |
| available_models = dit_handler.get_available_acestep_v15_models() | |
| if available_models: | |
| available_models = [m for m in available_models if "base" in m.lower()] | |
| if available_models: | |
| if args.task_type in {"lego", "extract", "complete"}: | |
| preferred = "acestep-v15-base" | |
| else: | |
| preferred = "acestep-v15-turbo" | |
| args.config_path = preferred if preferred in available_models else available_models[0] | |
| print(f"Auto-selected config_path: {args.config_path}") | |
| else: | |
| parser.error("No available DiT models found. Please specify --config_path.") | |
| if args.task_type in {"lego", "extract", "complete"} and "base" not in str(args.config_path).lower(): | |
| parser.error(f"task_type '{args.task_type}' requires a base model config (e.g., 'acestep-v15-base').") | |
| # Ensure required DiT/main models are present for the selected task/model. | |
| from acestep.model_downloader import ( | |
| ensure_main_model, | |
| ensure_dit_model, | |
| get_checkpoints_dir, | |
| check_main_model_exists, | |
| check_model_exists, | |
| SUBMODEL_REGISTRY, | |
| ) | |
| checkpoints_dir = get_checkpoints_dir() | |
| if not check_main_model_exists(checkpoints_dir): | |
| print("Main model components not found. Downloading main model...") | |
| success, msg = ensure_main_model(checkpoints_dir) | |
| print(msg) | |
| if not success: | |
| parser.error(f"Failed to download main model: {msg}") | |
| if args.config_path: | |
| config_name = str(args.config_path) | |
| known_models = {"acestep-v15-turbo"} | set(SUBMODEL_REGISTRY.keys()) | |
| if check_model_exists(config_name, checkpoints_dir): | |
| pass | |
| elif config_name in known_models: | |
| success, msg = ensure_dit_model(config_name, checkpoints_dir) | |
| if not success: | |
| parser.error(f"Failed to download DiT model '{config_name}': {msg}") | |
| else: | |
| print(f"Warning: DiT model '{config_name}' not found locally and not in registry. Skipping auto-download.") | |
| use_flash_attention = args.use_flash_attention | |
| if use_flash_attention is None: | |
| use_flash_attention = dit_handler.is_flash_attention_available(device) | |
| compile_model = os.environ.get("ACESTEP_COMPILE_MODEL", "").strip().lower() in { | |
| "1", "true", "yes", "y", "on", | |
| } | |
| print(f"Initializing DiT handler with model: {args.config_path}") | |
| dit_handler.initialize_service( | |
| project_root=args.project_root, | |
| config_path=args.config_path, | |
| device=device, | |
| use_flash_attention=use_flash_attention, | |
| compile_model=compile_model, | |
| offload_to_cpu=args.offload_to_cpu, | |
| offload_dit_to_cpu=args.offload_dit_to_cpu, | |
| ) | |
| if requires_lm: | |
| from acestep.model_downloader import ensure_lm_model | |
| if args.lm_model_path is None: | |
| available_lm_models = llm_handler.get_available_5hz_lm_models() | |
| if available_lm_models: | |
| args.lm_model_path = available_lm_models[0] | |
| print(f"Using default LM model: {args.lm_model_path}") | |
| else: | |
| success, msg = ensure_lm_model(checkpoints_dir=checkpoints_dir) | |
| print(msg) | |
| if not success: | |
| parser.error("No LM models available. Please specify --lm_model_path or disable --thinking.") | |
| available_lm_models = llm_handler.get_available_5hz_lm_models() | |
| if not available_lm_models: | |
| parser.error("No LM models available after download. Please specify --lm_model_path or disable --thinking.") | |
| args.lm_model_path = available_lm_models[0] | |
| print(f"Using default LM model: {args.lm_model_path}") | |
| else: | |
| lm_model_path = str(args.lm_model_path) | |
| if os.path.isabs(lm_model_path) and os.path.exists(lm_model_path): | |
| pass | |
| elif check_model_exists(lm_model_path, checkpoints_dir): | |
| pass | |
| elif lm_model_path in SUBMODEL_REGISTRY: | |
| success, msg = ensure_lm_model(lm_model_path, checkpoints_dir=checkpoints_dir) | |
| print(msg) | |
| if not success: | |
| parser.error(f"Failed to download LM model '{lm_model_path}': {msg}") | |
| else: | |
| parser.error(f"LM model '{lm_model_path}' not found locally and not in registry. Please provide a valid --lm_model_path.") | |
| print(f"Initializing LM handler with model: {args.lm_model_path}") | |
| llm_handler.initialize( | |
| checkpoint_dir=args.checkpoint_dir, | |
| lm_model_path=args.lm_model_path, | |
| backend=args.backend, | |
| device=device, | |
| offload_to_cpu=args.offload_to_cpu, | |
| dtype=None, | |
| ) | |
| else: | |
| if args.task_type in skip_lm_tasks: | |
| print(f"LM is not required for task_type '{args.task_type}'. Skipping LM handler initialization.") | |
| else: | |
| print("LM 'thinking' is disabled. Skipping LM handler initialization.") | |
| print("Handlers initialized.") | |
| format_has_duration = False | |
| # --- Sample Mode / Description-based Auto-Generation --- | |
| if args.sample_mode or (args.sample_query and str(args.sample_query).strip()): | |
| if not llm_handler.llm_initialized: | |
| parser.error("--sample_mode/sample_query requires the LM handler, but it's not initialized.") | |
| sample_query = args.sample_query if args.sample_query and str(args.sample_query).strip() else "NO USER INPUT" | |
| parsed_language, parsed_instrumental = _parse_description_hints(sample_query) | |
| if args.vocal_language and args.vocal_language not in ("en", "unknown", ""): | |
| sample_language = args.vocal_language | |
| else: | |
| sample_language = parsed_language | |
| print("\nINFO: Creating sample via 'create_sample'...") | |
| sample_result = create_sample( | |
| llm_handler=llm_handler, | |
| query=sample_query, | |
| instrumental=parsed_instrumental, | |
| vocal_language=sample_language, | |
| temperature=args.lm_temperature, | |
| top_k=args.lm_top_k, | |
| top_p=args.lm_top_p, | |
| ) | |
| if sample_result.success: | |
| args.caption = sample_result.caption | |
| args.lyrics = sample_result.lyrics | |
| args.instrumental = bool(sample_result.instrumental) | |
| if args.bpm is None: | |
| args.bpm = sample_result.bpm | |
| if not args.keyscale: | |
| args.keyscale = sample_result.keyscale | |
| if not args.timesignature: | |
| args.timesignature = sample_result.timesignature | |
| if args.duration <= 0: | |
| args.duration = sample_result.duration | |
| if args.vocal_language in ("unknown", "", None): | |
| args.vocal_language = sample_result.language | |
| args.sample_mode = True | |
| print("✓ Sample created. Using generated parameters.") | |
| else: | |
| parser.error(f"create_sample failed: {sample_result.error or sample_result.status_message}") | |
| # --- Format caption/lyrics if requested --- | |
| if args.use_format and (args.caption or args.lyrics): | |
| if not llm_handler.llm_initialized: | |
| parser.error("--use_format requires the LM handler, but it's not initialized.") | |
| user_metadata_for_format = {} | |
| if args.bpm is not None: | |
| user_metadata_for_format["bpm"] = args.bpm | |
| if args.duration is not None and float(args.duration) > 0: | |
| user_metadata_for_format["duration"] = float(args.duration) | |
| if args.keyscale: | |
| user_metadata_for_format["keyscale"] = args.keyscale | |
| if args.timesignature: | |
| user_metadata_for_format["timesignature"] = args.timesignature | |
| if args.vocal_language and args.vocal_language != "unknown": | |
| user_metadata_for_format["language"] = args.vocal_language | |
| print("\nINFO: Formatting caption/lyrics via 'format_sample'...") | |
| format_result = format_sample( | |
| llm_handler=llm_handler, | |
| caption=args.caption or "", | |
| lyrics=args.lyrics or "", | |
| user_metadata=user_metadata_for_format if user_metadata_for_format else None, | |
| temperature=args.lm_temperature, | |
| top_k=args.lm_top_k, | |
| top_p=args.lm_top_p, | |
| ) | |
| if format_result.success: | |
| args.caption = format_result.caption or args.caption | |
| args.lyrics = format_result.lyrics or args.lyrics | |
| if format_result.duration: | |
| args.duration = format_result.duration | |
| format_has_duration = True | |
| if format_result.bpm: | |
| args.bpm = format_result.bpm | |
| if format_result.keyscale: | |
| args.keyscale = format_result.keyscale | |
| if format_result.timesignature: | |
| args.timesignature = format_result.timesignature | |
| print("✓ Format complete.") | |
| else: | |
| parser.error(f"format_sample failed: {format_result.error or format_result.status_message}") | |
| # --- Auto-generate Lyrics if Requested --- | |
| if args.use_cot_lyrics: | |
| if not llm_handler.llm_initialized: | |
| parser.error("--use_cot_lyrics requires the LM handler, but it's not initialized. Ensure --thinking is enabled.") | |
| print("\nINFO: Generating lyrics and metadata via 'create_sample'...") | |
| sample_result = create_sample( | |
| llm_handler=llm_handler, | |
| query=args.caption, | |
| instrumental=False, | |
| vocal_language=args.vocal_language if args.vocal_language != 'unknown' else None, | |
| temperature=args.lm_temperature, | |
| top_k=args.lm_top_k, | |
| top_p=args.lm_top_p, | |
| ) | |
| if sample_result.success: | |
| print("✓ Automatic sample creation successful. Using generated parameters:") | |
| # Update args with values from create_sample, respecting user-provided values | |
| args.caption = sample_result.caption | |
| args.lyrics = sample_result.lyrics | |
| if args.bpm is None: args.bpm = sample_result.bpm | |
| if not args.keyscale: args.keyscale = sample_result.keyscale | |
| if not args.timesignature: args.timesignature = sample_result.timesignature | |
| if args.duration <= 0: args.duration = sample_result.duration | |
| if args.vocal_language == 'unknown': args.vocal_language = sample_result.language | |
| print(f" - Caption: {args.caption}") | |
| lyrics_preview = args.lyrics[:150].strip().replace("\n", " ") | |
| print(f" - Lyrics: '{lyrics_preview}...'") | |
| print(f" - Metadata: BPM={args.bpm}, Key='{args.keyscale}', Lang='{args.vocal_language}'") | |
| # Disable subsequent CoT steps to avoid redundancy and save time | |
| args.use_cot_metas = False | |
| args.use_cot_caption = False | |
| else: | |
| print(f"⚠️ WARNING: Automatic lyric generation via 'create_sample' failed: {sample_result.error}") | |
| print(" Proceeding with an instrumental track instead.") | |
| args.lyrics = "[Instrumental]" | |
| args.instrumental = True | |
| # Flag has served its purpose, disable it to avoid issues with GenerationParams | |
| args.use_cot_lyrics = False | |
| if args.sample_mode or format_has_duration: | |
| args.use_cot_metas = False | |
| # --- Prompt Editing Hook for LLM Audio Tokens --- | |
| if args.thinking and args.task_type not in skip_lm_tasks: | |
| instruction_path = os.path.join( | |
| os.path.abspath(args.project_root) if args.project_root else os.getcwd(), | |
| "instruction.txt", | |
| ) | |
| preloaded_prompt = None | |
| use_instruction_file = False | |
| if args.config and os.path.exists(instruction_path): | |
| use_instruction_file = True | |
| try: | |
| with open(instruction_path, "r", encoding="utf-8") as f: | |
| preloaded_prompt = f.read() | |
| except Exception as e: | |
| print(f"WARNING: Failed to read {instruction_path}: {e}") | |
| preloaded_prompt = None | |
| use_instruction_file = False | |
| if use_instruction_file: | |
| print(f"INFO: Found {instruction_path}. Using it without editing.") | |
| if preloaded_prompt is not None and not preloaded_prompt.strip(): | |
| preloaded_prompt = None | |
| _install_prompt_edit_hook(llm_handler, instruction_path, preloaded_prompt=preloaded_prompt) | |
| # --- Configure Generation --- | |
| params = GenerationParams( | |
| task_type=args.task_type, | |
| instruction=args.instruction, | |
| reference_audio=args.reference_audio, | |
| src_audio=args.src_audio, | |
| audio_codes=args.audio_codes, | |
| caption=args.caption, | |
| lyrics=args.lyrics, | |
| instrumental=args.instrumental, | |
| vocal_language=args.vocal_language, | |
| bpm=args.bpm, | |
| keyscale=args.keyscale, | |
| timesignature=args.timesignature, | |
| duration=args.duration, | |
| inference_steps=args.inference_steps, | |
| seed=args.seed, | |
| guidance_scale=args.guidance_scale, | |
| use_adg=args.use_adg, | |
| cfg_interval_start=args.cfg_interval_start, | |
| cfg_interval_end=args.cfg_interval_end, | |
| shift=args.shift, | |
| infer_method=args.infer_method, | |
| timesteps=timesteps, | |
| repainting_start=args.repainting_start, | |
| repainting_end=args.repainting_end, | |
| audio_cover_strength=args.audio_cover_strength, | |
| thinking=args.thinking, | |
| lm_temperature=args.lm_temperature, | |
| lm_cfg_scale=args.lm_cfg_scale, | |
| lm_top_k=args.lm_top_k, | |
| lm_top_p=args.lm_top_p, | |
| lm_negative_prompt=args.lm_negative_prompt, | |
| use_cot_metas=args.use_cot_metas, | |
| use_cot_caption=args.use_cot_caption, | |
| use_cot_lyrics=args.use_cot_lyrics, | |
| use_cot_language=args.use_cot_language, | |
| use_constrained_decoding=args.use_constrained_decoding | |
| ) | |
| config = GenerationConfig( | |
| batch_size=args.batch_size, | |
| allow_lm_batch=args.allow_lm_batch, | |
| use_random_seed=args.use_random_seed, | |
| seeds=args.seeds, | |
| lm_batch_chunk_size=args.lm_batch_chunk_size, | |
| constrained_decoding_debug=args.constrained_decoding_debug, | |
| audio_format=args.audio_format | |
| ) | |
| # --- Generate Music --- | |
| log_level = getattr(args, "log_level", "INFO") | |
| log_level_upper = str(log_level).upper() | |
| compact_logs = log_level_upper != "DEBUG" | |
| _print_final_parameters( | |
| args, | |
| params, | |
| config, | |
| params_defaults, | |
| config_defaults, | |
| compact=compact_logs, | |
| resolved_device=device, | |
| ) | |
| print("\n--- Starting Generation ---") | |
| print(f"Caption: \"{params.caption}\"") | |
| print(f"Duration: {params.duration}s | Outputs: {config.batch_size}") | |
| if config.seeds: | |
| print(f"Custom Seeds: {config.seeds}") | |
| print("---------------------------\n") | |
| manual_edit_pipeline = ( | |
| args.thinking | |
| and args.task_type not in skip_lm_tasks | |
| and not (params.audio_codes and str(params.audio_codes).strip()) | |
| ) | |
| lm_time_costs = None | |
| if manual_edit_pipeline: | |
| top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k) | |
| top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p | |
| actual_batch_size = config.batch_size if config.batch_size is not None else 1 | |
| seed_for_generation = "" | |
| if config.seeds is not None: | |
| if isinstance(config.seeds, list) and len(config.seeds) > 0: | |
| seed_for_generation = ",".join(str(s) for s in config.seeds) | |
| elif isinstance(config.seeds, int): | |
| seed_for_generation = str(config.seeds) | |
| actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed) | |
| original_target_duration = params.duration | |
| original_bpm = params.bpm | |
| original_keyscale = params.keyscale | |
| original_timesignature = params.timesignature | |
| original_vocal_language = params.vocal_language | |
| lm_result = None | |
| lm_metadata = {} | |
| edited_caption = None | |
| edited_lyrics = None | |
| edited_instruction = None | |
| edited_metas = {} | |
| lm_time_costs = { | |
| "phase1_time": 0.0, | |
| "phase2_time": 0.0, | |
| "total_time": 0.0, | |
| } | |
| for attempt in range(2): | |
| user_metadata = {} | |
| if params.bpm is not None: | |
| try: | |
| bpm_value = float(params.bpm) | |
| if bpm_value > 0: | |
| user_metadata["bpm"] = int(bpm_value) | |
| except (ValueError, TypeError): | |
| pass | |
| if params.keyscale and params.keyscale.strip() and params.keyscale.strip().lower() not in ["n/a", ""]: | |
| user_metadata["keyscale"] = params.keyscale.strip() | |
| if params.timesignature and params.timesignature.strip() and params.timesignature.strip().lower() not in ["n/a", ""]: | |
| user_metadata["timesignature"] = params.timesignature.strip() | |
| if params.duration is not None: | |
| try: | |
| duration_value = float(params.duration) | |
| if duration_value > 0: | |
| user_metadata["duration"] = int(duration_value) | |
| except (ValueError, TypeError): | |
| pass | |
| # Only include caption and language in user_metadata on | |
| # regeneration attempts. On the first attempt the LM should | |
| # generate/expand these via CoT (matching inference.py behaviour). | |
| if attempt > 0: | |
| if params.caption and params.caption.strip(): | |
| user_metadata["caption"] = params.caption.strip() | |
| if params.vocal_language and params.vocal_language not in ("", "unknown"): | |
| user_metadata["language"] = params.vocal_language | |
| user_metadata_to_pass = user_metadata if user_metadata else None | |
| lm_result = llm_handler.generate_with_stop_condition( | |
| caption=params.caption or "", | |
| lyrics=params.lyrics or "", | |
| infer_type="llm_dit", | |
| temperature=params.lm_temperature, | |
| cfg_scale=params.lm_cfg_scale, | |
| negative_prompt=params.lm_negative_prompt, | |
| top_k=top_k_value, | |
| top_p=top_p_value, | |
| target_duration=params.duration, | |
| user_metadata=user_metadata_to_pass, | |
| use_cot_caption=params.use_cot_caption, | |
| use_cot_language=params.use_cot_language, | |
| use_cot_metas=params.use_cot_metas, | |
| use_constrained_decoding=params.use_constrained_decoding, | |
| constrained_decoding_debug=config.constrained_decoding_debug, | |
| batch_size=actual_batch_size, | |
| seeds=actual_seed_list, | |
| ) | |
| lm_extra_time = (lm_result.get("extra_outputs") or {}).get("time_costs", {}) | |
| if lm_extra_time: | |
| lm_time_costs["phase1_time"] += float(lm_extra_time.get("phase1_time", 0.0) or 0.0) | |
| lm_time_costs["phase2_time"] += float(lm_extra_time.get("phase2_time", 0.0) or 0.0) | |
| lm_time_costs["total_time"] += float( | |
| lm_extra_time.get( | |
| "total_time", | |
| (lm_extra_time.get("phase1_time", 0.0) or 0.0) | |
| + (lm_extra_time.get("phase2_time", 0.0) or 0.0), | |
| ) | |
| or 0.0 | |
| ) | |
| if not lm_result.get("success", False): | |
| error_msg = lm_result.get("error", "Unknown LM error") | |
| print(f"\n❌ Generation failed: {error_msg}") | |
| print(f" Status: {lm_result.get('error', '')}") | |
| return | |
| if actual_batch_size > 1: | |
| lm_metadata = (lm_result.get("metadata") or [{}])[0] | |
| audio_codes = lm_result.get("audio_codes", []) | |
| else: | |
| lm_metadata = lm_result.get("metadata", {}) or {} | |
| audio_codes = lm_result.get("audio_codes", "") | |
| if audio_codes: | |
| params.audio_codes = audio_codes | |
| else: | |
| print("WARNING: LM did not return audio codes; proceeding without codes.") | |
| edited_caption = getattr(llm_handler, "_edited_caption", None) | |
| edited_lyrics = getattr(llm_handler, "_edited_lyrics", None) | |
| edited_instruction = getattr(llm_handler, "_edited_instruction", None) | |
| edited_metas = getattr(llm_handler, "_edited_metas", {}) | |
| parsed_duration = None | |
| parsed_bpm = None | |
| parsed_keyscale = None | |
| parsed_timesignature = None | |
| parsed_language = None | |
| if edited_metas: | |
| bpm_value = edited_metas.get("bpm") | |
| if bpm_value: | |
| parsed = _parse_number(bpm_value) | |
| if parsed is not None and parsed > 0: | |
| parsed_bpm = int(parsed) | |
| duration_value = edited_metas.get("duration") | |
| if duration_value: | |
| parsed = _parse_number(duration_value) | |
| if parsed is not None and parsed > 0: | |
| parsed_duration = float(parsed) | |
| keyscale_value = edited_metas.get("keyscale") | |
| if keyscale_value: | |
| parsed_keyscale = keyscale_value | |
| timesignature_value = edited_metas.get("timesignature") | |
| if timesignature_value: | |
| parsed_timesignature = timesignature_value | |
| language_value = edited_metas.get("language") or edited_metas.get("vocal_language") | |
| if language_value: | |
| parsed_language = language_value | |
| if attempt == 0: | |
| duration_changed = parsed_duration is not None and ( | |
| original_target_duration is None | |
| or float(original_target_duration) <= 0 | |
| or abs(float(original_target_duration) - parsed_duration) > 1e-6 | |
| ) | |
| bpm_changed = parsed_bpm is not None and parsed_bpm != original_bpm | |
| keyscale_changed = parsed_keyscale is not None and parsed_keyscale != original_keyscale | |
| timesignature_changed = parsed_timesignature is not None and parsed_timesignature != original_timesignature | |
| language_changed = parsed_language is not None and parsed_language != original_vocal_language | |
| if duration_changed or bpm_changed or keyscale_changed or timesignature_changed or language_changed: | |
| if duration_changed: | |
| params.duration = parsed_duration | |
| if bpm_changed: | |
| params.bpm = parsed_bpm | |
| if keyscale_changed: | |
| params.keyscale = parsed_keyscale | |
| if timesignature_changed: | |
| params.timesignature = parsed_timesignature | |
| if language_changed: | |
| params.vocal_language = parsed_language | |
| # Carry forward the expanded caption so the second | |
| # attempt's <think> block (and user_metadata) use it | |
| # instead of the short original caption. | |
| edited_caption_for_regen = edited_metas.get("caption") if edited_metas else None | |
| if edited_caption_for_regen and edited_caption_for_regen.strip(): | |
| params.caption = edited_caption_for_regen | |
| print("INFO: Edited metadata detected. Regenerating audio codes with updated values.") | |
| llm_handler._skip_prompt_edit = True | |
| continue | |
| break | |
| edited_meta_caption = edited_metas.get("caption") if edited_metas else None | |
| if edited_meta_caption and edited_meta_caption.strip(): | |
| params.caption = edited_meta_caption | |
| elif edited_caption: | |
| params.caption = edited_caption | |
| elif params.use_cot_caption and lm_metadata.get("caption"): | |
| params.caption = lm_metadata.get("caption") | |
| if edited_lyrics: | |
| params.lyrics = edited_lyrics | |
| elif not params.lyrics and lm_metadata.get("lyrics"): | |
| params.lyrics = lm_metadata.get("lyrics") | |
| if edited_instruction: | |
| params.instruction = edited_instruction | |
| if edited_metas: | |
| bpm_value = edited_metas.get("bpm") | |
| if bpm_value: | |
| parsed = _parse_number(bpm_value) | |
| if parsed is not None: | |
| params.bpm = int(parsed) | |
| duration_value = edited_metas.get("duration") | |
| if duration_value: | |
| parsed = _parse_number(duration_value) | |
| if parsed is not None: | |
| params.duration = float(parsed) | |
| keyscale_value = edited_metas.get("keyscale") | |
| if keyscale_value: | |
| params.keyscale = keyscale_value | |
| timesignature_value = edited_metas.get("timesignature") | |
| if timesignature_value: | |
| params.timesignature = timesignature_value | |
| language_value = edited_metas.get("language") or edited_metas.get("vocal_language") | |
| if language_value: | |
| params.vocal_language = language_value | |
| else: | |
| if params.bpm is None and lm_metadata.get("bpm") not in (None, "N/A", ""): | |
| parsed = _parse_number(str(lm_metadata.get("bpm"))) | |
| if parsed is not None: | |
| params.bpm = int(parsed) | |
| if not params.keyscale and lm_metadata.get("keyscale"): | |
| params.keyscale = lm_metadata.get("keyscale") | |
| if not params.timesignature and lm_metadata.get("timesignature"): | |
| params.timesignature = lm_metadata.get("timesignature") | |
| if params.duration is None and lm_metadata.get("duration") not in (None, "N/A", ""): | |
| parsed = _parse_number(str(lm_metadata.get("duration"))) | |
| if parsed is not None: | |
| params.duration = float(parsed) | |
| if params.vocal_language in (None, "", "unknown"): | |
| language_value = lm_metadata.get("vocal_language") or lm_metadata.get("language") | |
| if language_value: | |
| params.vocal_language = language_value | |
| # use_cot_language: override vocal_language with LM detection unless | |
| # the user explicitly edited the language in the think block. | |
| if params.use_cot_language: | |
| edited_lang = (edited_metas.get("language") or edited_metas.get("vocal_language")) if edited_metas else None | |
| if not edited_lang: | |
| lm_lang = lm_metadata.get("vocal_language") or lm_metadata.get("language") | |
| if lm_lang: | |
| params.vocal_language = lm_lang | |
| # Populate cot_* fields for downstream reporting (mirrors inference.py) | |
| if lm_metadata: | |
| if original_bpm is None: | |
| params.cot_bpm = params.bpm | |
| if not original_keyscale: | |
| params.cot_keyscale = params.keyscale | |
| if not original_timesignature: | |
| params.cot_timesignature = params.timesignature | |
| if original_target_duration is None or float(original_target_duration) <= 0: | |
| params.cot_duration = params.duration | |
| if original_vocal_language in (None, "", "unknown"): | |
| params.cot_vocal_language = params.vocal_language | |
| if not params.caption: | |
| params.cot_caption = lm_metadata.get("caption", "") | |
| if not params.lyrics: | |
| params.cot_lyrics = lm_metadata.get("lyrics", "") | |
| params.thinking = False | |
| params.use_cot_caption = False | |
| params.use_cot_language = False | |
| params.use_cot_metas = False | |
| if hasattr(llm_handler, "_skip_prompt_edit"): | |
| llm_handler._skip_prompt_edit = False | |
| if log_level_upper in {"INFO", "DEBUG"}: | |
| _print_dit_prompt(dit_handler, params) | |
| print("Running DiT generation with edited prompt and cached audio codes...") | |
| result = generate_music(dit_handler, llm_handler, params, config, save_dir=args.save_dir) | |
| else: | |
| if log_level_upper in {"INFO", "DEBUG"}: | |
| _print_dit_prompt(dit_handler, params) | |
| result = generate_music(dit_handler, llm_handler, params, config, save_dir=args.save_dir) | |
| # --- Process Results --- | |
| if result.success: | |
| print(f"\n✅ Generation successful! {len(result.audios)} audio(s) saved in '{args.save_dir}/'") | |
| for i, audio in enumerate(result.audios): | |
| print(f" [{i+1}] Path: {audio['path']} | Seed: {audio['params']['seed']}") | |
| time_costs = result.extra_outputs.get("time_costs", {}) | |
| if manual_edit_pipeline and lm_time_costs and time_costs is not None: | |
| if not isinstance(time_costs, dict): | |
| time_costs = {} | |
| result.extra_outputs["time_costs"] = time_costs | |
| if lm_time_costs["total_time"] > 0.0: | |
| time_costs["lm_phase1_time"] = lm_time_costs["phase1_time"] | |
| time_costs["lm_phase2_time"] = lm_time_costs["phase2_time"] | |
| time_costs["lm_total_time"] = lm_time_costs["total_time"] | |
| dit_total = float(time_costs.get("dit_total_time_cost", 0.0) or 0.0) | |
| time_costs["pipeline_total_time"] = time_costs["lm_total_time"] + dit_total | |
| if time_costs: | |
| print("\n--- Performance ---") | |
| total_time = time_costs.get('pipeline_total_time', 0) | |
| print(f"Total time: {total_time:.2f}s") | |
| if args.thinking: | |
| lm1_time = time_costs.get('lm_phase1_time', 0) | |
| lm2_time = time_costs.get('lm_phase2_time', 0) | |
| print(f" - LM time: {lm1_time + lm2_time:.2f}s") | |
| dit_time = time_costs.get('dit_total_time_cost', 0) | |
| print(f" - DiT time: {dit_time:.2f}s") | |
| print("-------------------\n") | |
| else: | |
| print(f"\n❌ Generation failed: {result.error}") | |
| print(f" Status: {result.status_message}") | |
| # Script entry point: invoke main() only when this module is executed directly, not when it is imported. | |
| if __name__ == "__main__": | |
| main() | |