import argparse import base64 import json import os import sys import time from pathlib import Path from urllib.error import HTTPError, URLError from urllib.request import Request, urlopen DEFAULT_URL = "https://your-endpoint-url.endpoints.huggingface.cloud" DEFAULT_SAMPLE_RATE = 44100 def read_dotenv_value(key: str, dotenv_path: str = ".env") -> str: path = Path(dotenv_path) if not path.exists(): return "" for raw in path.read_text(encoding="utf-8").splitlines(): line = raw.strip() if not line or line.startswith("#") or "=" not in line: continue k, v = line.split("=", 1) if k.strip() == key: return v.strip().strip('"').strip("'") return "" def prompt_text(label: str, default: str = "", required: bool = False) -> str: while True: suffix = f" [{default}]" if default else "" value = input(f"{label}{suffix}: ").strip() if not value: value = default if value or not required: return value print("Value required.") def prompt_int(label: str, default: int | None = None, allow_blank: bool = False) -> int | None: while True: default_str = "" if default is None else str(default) value = prompt_text(label, default_str, required=not allow_blank) if not value and allow_blank: return None try: return int(value) except ValueError: print("Enter a valid integer.") def prompt_float(label: str, default: float) -> float: while True: value = prompt_text(label, str(default), required=True) try: return float(value) except ValueError: print("Enter a valid number.") def prompt_yes_no(label: str, default: bool) -> bool: default_text = "y" if default else "n" while True: value = prompt_text(f"{label} (y/n)", default_text, required=True).lower() if value in {"y", "yes", "1", "true", "t"}: return True if value in {"n", "no", "0", "false", "f"}: return False print("Please answer y or n.") def prompt_multiline(label: str, end_token: str = "END") -> str: print(label) print(f"Finish lyrics by typing {end_token} on its own line.") lines: list[str] = [] while True: line = input() if line.strip() == end_token: break lines.append(line) return "\n".join(lines).strip() def prompt_lyrics_optional() -> str: use_lyrics = prompt_yes_no("Add custom lyrics", True) if not use_lyrics: return "" return prompt_multiline("Paste lyrics (or just type END for none)") def send_request(url: str, token: str, payload: dict) -> dict: data = json.dumps(payload).encode("utf-8") req = Request( url=url, data=data, method="POST", headers={ "Authorization": f"Bearer {token}", "Content-Type": "application/json", }, ) try: with urlopen(req, timeout=3600) as resp: body = resp.read().decode("utf-8") return json.loads(body) except HTTPError as e: text = e.read().decode("utf-8", errors="replace") raise RuntimeError(f"HTTP {e.code}: {text}") from e except URLError as e: raise RuntimeError(f"Network error: {e}") from e def resolve_token(cli_token: str) -> str: if cli_token: return cli_token env_token = os.getenv("HF_TOKEN") or os.getenv("hf_token") if env_token: return env_token dotenv_token = read_dotenv_value("hf_token") or read_dotenv_value("HF_TOKEN") return dotenv_token def main() -> int: parser = argparse.ArgumentParser(description="Interactive ACE-Step endpoint generator") parser.add_argument("--url", default=os.getenv("HF_ENDPOINT_URL", DEFAULT_URL), help="Inference endpoint URL") parser.add_argument("--token", default="", help="HF token. If omitted, uses env/.env") parser.add_argument("--prompt", default="", help="Initial prompt") parser.add_argument("--out-file", default="", help="Output WAV file path") parser.add_argument( "--advanced", action="store_true", help="Ask advanced generation options (seed/guidance/steps/sample-rate/LM).", ) args = parser.parse_args() print("=== ACE-Step Interactive Generation ===") token = resolve_token(args.token) if not token: print("No token found. Set HF_TOKEN or hf_token in .env, or pass --token.") return 1 url = prompt_text("Endpoint URL", args.url, required=True) music_prompt = prompt_text("Music prompt", args.prompt, required=True) bpm = prompt_int("BPM (blank for auto)", None, allow_blank=True) duration_sec = prompt_int("Duration seconds", 120) instrumental = prompt_yes_no("Instrumental (no vocals)", False) lyrics = "" if instrumental else prompt_lyrics_optional() # Quality-first defaults: use SFT + LM path configured on the endpoint. sample_rate = DEFAULT_SAMPLE_RATE steps = 50 guidance_scale = 7.0 seed = 42 use_lm = True allow_fallback = False simple_prompt = False if args.advanced: print("\nAdvanced options:") sample_rate = prompt_int("Sample rate", DEFAULT_SAMPLE_RATE) steps = prompt_int("Steps", 50) guidance_scale = prompt_float("Guidance scale", 7.0) seed = prompt_int("Seed", 42) use_lm = prompt_yes_no("Use LM planning (higher quality, slower)", True) allow_fallback = prompt_yes_no("Allow fallback sine audio", False) default_out = args.out_file or f"music_{int(time.time())}.wav" out_file = prompt_text("Output file", default_out, required=True) inputs = { "prompt": music_prompt, "duration_sec": duration_sec, "sample_rate": sample_rate, "seed": seed, "guidance_scale": guidance_scale, "steps": steps, "use_lm": use_lm, "simple_prompt": simple_prompt, "instrumental": instrumental, "allow_fallback": allow_fallback, } if bpm is not None: inputs["bpm"] = bpm if lyrics: inputs["lyrics"] = lyrics payload = {"inputs": inputs} print("\nSending request...") try: response = send_request(url, token, payload) except Exception as e: print(f"Request failed: {e}") return 1 print("Response summary:") print(json.dumps({ "used_fallback": response.get("used_fallback"), "model_loaded": response.get("model_loaded"), "model_error": response.get("model_error"), "sample_rate": response.get("sample_rate"), "duration_sec": response.get("duration_sec"), }, indent=2)) if response.get("error"): print(f"Endpoint error: {response['error']}") return 1 audio_b64 = response.get("audio_base64_wav") if not audio_b64: print("No audio_base64_wav in response.") return 1 audio_bytes = base64.b64decode(audio_b64) Path(out_file).write_bytes(audio_bytes) print(f"Saved audio: {out_file}") return 0 if __name__ == "__main__": raise SystemExit(main())