ace-step15-endpoint / scripts /endpoint /generate_interactive.py
Andrew
github push
bd37cca
import argparse
import base64
import json
import os
import sys
import time
from pathlib import Path
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
DEFAULT_URL = "https://your-endpoint-url.endpoints.huggingface.cloud"
DEFAULT_SAMPLE_RATE = 44100
def read_dotenv_value(key: str, dotenv_path: str = ".env") -> str:
path = Path(dotenv_path)
if not path.exists():
return ""
for raw in path.read_text(encoding="utf-8").splitlines():
line = raw.strip()
if not line or line.startswith("#") or "=" not in line:
continue
k, v = line.split("=", 1)
if k.strip() == key:
return v.strip().strip('"').strip("'")
return ""
def prompt_text(label: str, default: str = "", required: bool = False) -> str:
while True:
suffix = f" [{default}]" if default else ""
value = input(f"{label}{suffix}: ").strip()
if not value:
value = default
if value or not required:
return value
print("Value required.")
def prompt_int(label: str, default: int | None = None, allow_blank: bool = False) -> int | None:
while True:
default_str = "" if default is None else str(default)
value = prompt_text(label, default_str, required=not allow_blank)
if not value and allow_blank:
return None
try:
return int(value)
except ValueError:
print("Enter a valid integer.")
def prompt_float(label: str, default: float) -> float:
while True:
value = prompt_text(label, str(default), required=True)
try:
return float(value)
except ValueError:
print("Enter a valid number.")
def prompt_yes_no(label: str, default: bool) -> bool:
default_text = "y" if default else "n"
while True:
value = prompt_text(f"{label} (y/n)", default_text, required=True).lower()
if value in {"y", "yes", "1", "true", "t"}:
return True
if value in {"n", "no", "0", "false", "f"}:
return False
print("Please answer y or n.")
def prompt_multiline(label: str, end_token: str = "END") -> str:
print(label)
print(f"Finish lyrics by typing {end_token} on its own line.")
lines: list[str] = []
while True:
line = input()
if line.strip() == end_token:
break
lines.append(line)
return "\n".join(lines).strip()
def prompt_lyrics_optional() -> str:
use_lyrics = prompt_yes_no("Add custom lyrics", True)
if not use_lyrics:
return ""
return prompt_multiline("Paste lyrics (or just type END for none)")
def send_request(url: str, token: str, payload: dict) -> dict:
data = json.dumps(payload).encode("utf-8")
req = Request(
url=url,
data=data,
method="POST",
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
)
try:
with urlopen(req, timeout=3600) as resp:
body = resp.read().decode("utf-8")
return json.loads(body)
except HTTPError as e:
text = e.read().decode("utf-8", errors="replace")
raise RuntimeError(f"HTTP {e.code}: {text}") from e
except URLError as e:
raise RuntimeError(f"Network error: {e}") from e
def resolve_token(cli_token: str) -> str:
if cli_token:
return cli_token
env_token = os.getenv("HF_TOKEN") or os.getenv("hf_token")
if env_token:
return env_token
dotenv_token = read_dotenv_value("hf_token") or read_dotenv_value("HF_TOKEN")
return dotenv_token
def main() -> int:
parser = argparse.ArgumentParser(description="Interactive ACE-Step endpoint generator")
parser.add_argument("--url", default=os.getenv("HF_ENDPOINT_URL", DEFAULT_URL), help="Inference endpoint URL")
parser.add_argument("--token", default="", help="HF token. If omitted, uses env/.env")
parser.add_argument("--prompt", default="", help="Initial prompt")
parser.add_argument("--out-file", default="", help="Output WAV file path")
parser.add_argument(
"--advanced",
action="store_true",
help="Ask advanced generation options (seed/guidance/steps/sample-rate/LM).",
)
args = parser.parse_args()
print("=== ACE-Step Interactive Generation ===")
token = resolve_token(args.token)
if not token:
print("No token found. Set HF_TOKEN or hf_token in .env, or pass --token.")
return 1
url = prompt_text("Endpoint URL", args.url, required=True)
music_prompt = prompt_text("Music prompt", args.prompt, required=True)
bpm = prompt_int("BPM (blank for auto)", None, allow_blank=True)
duration_sec = prompt_int("Duration seconds", 120)
instrumental = prompt_yes_no("Instrumental (no vocals)", False)
lyrics = "" if instrumental else prompt_lyrics_optional()
# Quality-first defaults: use SFT + LM path configured on the endpoint.
sample_rate = DEFAULT_SAMPLE_RATE
steps = 50
guidance_scale = 7.0
seed = 42
use_lm = True
allow_fallback = False
simple_prompt = False
if args.advanced:
print("\nAdvanced options:")
sample_rate = prompt_int("Sample rate", DEFAULT_SAMPLE_RATE)
steps = prompt_int("Steps", 50)
guidance_scale = prompt_float("Guidance scale", 7.0)
seed = prompt_int("Seed", 42)
use_lm = prompt_yes_no("Use LM planning (higher quality, slower)", True)
allow_fallback = prompt_yes_no("Allow fallback sine audio", False)
default_out = args.out_file or f"music_{int(time.time())}.wav"
out_file = prompt_text("Output file", default_out, required=True)
inputs = {
"prompt": music_prompt,
"duration_sec": duration_sec,
"sample_rate": sample_rate,
"seed": seed,
"guidance_scale": guidance_scale,
"steps": steps,
"use_lm": use_lm,
"simple_prompt": simple_prompt,
"instrumental": instrumental,
"allow_fallback": allow_fallback,
}
if bpm is not None:
inputs["bpm"] = bpm
if lyrics:
inputs["lyrics"] = lyrics
payload = {"inputs": inputs}
print("\nSending request...")
try:
response = send_request(url, token, payload)
except Exception as e:
print(f"Request failed: {e}")
return 1
print("Response summary:")
print(json.dumps({
"used_fallback": response.get("used_fallback"),
"model_loaded": response.get("model_loaded"),
"model_error": response.get("model_error"),
"sample_rate": response.get("sample_rate"),
"duration_sec": response.get("duration_sec"),
}, indent=2))
if response.get("error"):
print(f"Endpoint error: {response['error']}")
return 1
audio_b64 = response.get("audio_base64_wav")
if not audio_b64:
print("No audio_base64_wav in response.")
return 1
audio_bytes = base64.b64decode(audio_b64)
Path(out_file).write_bytes(audio_bytes)
print(f"Saved audio: {out_file}")
return 0
if __name__ == "__main__":
raise SystemExit(main())