File size: 7,212 Bytes
bd37cca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 | import argparse
import base64
import json
import os
import sys
import time
from pathlib import Path
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
# Placeholder Inference Endpoint URL. Used only when neither --url nor the
# HF_ENDPOINT_URL environment variable is provided; the user can still edit
# it at the interactive "Endpoint URL" prompt.
DEFAULT_URL: str = "https://your-endpoint-url.endpoints.huggingface.cloud"
# Default "sample_rate" sent in the request payload; only changed when the
# script runs with --advanced and the user enters another value.
DEFAULT_SAMPLE_RATE: int = 44100
def read_dotenv_value(key: str, dotenv_path: str = ".env") -> str:
    """Return the value of ``key`` from a simple ``KEY=VALUE`` dotenv file.

    Blank lines and lines starting with ``#`` are skipped. An optional
    leading ``export`` keyword before the name is accepted. One pair of
    matching surrounding quotes (single or double) is removed from the
    value. The first matching assignment wins.

    Args:
        key: Exact (case-sensitive) variable name to look up.
        dotenv_path: Path to the dotenv file; defaults to ``.env`` relative
            to the current working directory.

    Returns:
        The value, or ``""`` if the file is missing, is not a regular file,
        or does not assign ``key``.
    """
    path = Path(dotenv_path)
    if not path.is_file():
        return ""
    # utf-8-sig drops a leading BOM (e.g. files saved by Windows Notepad),
    # which would otherwise be glued onto the first key and hide it.
    for raw in path.read_text(encoding="utf-8-sig").splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        name, value = line.split("=", 1)
        name = name.strip()
        # Accept shell-style "export KEY=value" lines.
        if name.startswith("export") and name[6:7].isspace():
            name = name[6:].strip()
        if name != key:
            continue
        value = value.strip()
        # Remove one pair of *matching* quotes; an unbalanced quote is data.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        return value
    return ""
def prompt_text(label: str, default: str = "", required: bool = False) -> str:
    """Read one line of text from the user.

    A non-empty ``default`` is shown in brackets and used when the user just
    presses Enter. Surrounding whitespace is stripped from typed answers.
    With ``required`` set, the question repeats until the answer (typed or
    defaulted) is non-empty.
    """
    shown = f"{label} [{default}]" if default else label
    while True:
        answer = input(f"{shown}: ").strip() or default
        if answer or not required:
            return answer
        print("Value required.")
def prompt_int(label: str, default: int | None = None, allow_blank: bool = False) -> int | None:
    """Ask for an integer, re-asking until the answer parses.

    Returns ``None`` only when ``allow_blank`` is true and the answer is
    empty with no ``default`` to fall back on. Otherwise an answer is
    required.
    """
    shown_default = str(default) if default is not None else ""
    while True:
        raw = prompt_text(label, shown_default, required=not allow_blank)
        if allow_blank and not raw:
            return None
        try:
            parsed = int(raw)
        except ValueError:
            print("Enter a valid integer.")
            continue
        return parsed
def prompt_float(label: str, default: float) -> float:
    """Ask for a finite number, re-asking until the answer is acceptable.

    ``float()`` also accepts ``nan``/``inf``. Those are rejected here because
    ``json.dumps`` would serialize them as ``NaN``/``Infinity``, which are not
    valid JSON and would end up in the request body.

    Args:
        label: Prompt text.
        default: Value used when the user just presses Enter.

    Returns:
        The parsed, finite float.
    """
    import math

    while True:
        raw = prompt_text(label, str(default), required=True)
        try:
            number = float(raw)
        except ValueError:
            print("Enter a valid number.")
            continue
        if math.isfinite(number):
            return number
        print("Enter a finite number.")
def prompt_yes_no(label: str, default: bool) -> bool:
    """Ask a yes/no question until a recognised answer is given.

    Pressing Enter selects ``default``. Answers are matched case-insensitively
    against y/yes/1/true/t (True) and n/no/0/false/f (False).
    """
    answers = dict.fromkeys(("y", "yes", "1", "true", "t"), True)
    answers.update(dict.fromkeys(("n", "no", "0", "false", "f"), False))
    default_choice = "y" if default else "n"
    while True:
        reply = prompt_text(f"{label} (y/n)", default_choice, required=True).lower()
        if reply in answers:
            return answers[reply]
        print("Please answer y or n.")
def prompt_multiline(label: str, end_token: str = "END") -> str:
    """Read lines from stdin until ``end_token`` or end of input.

    Reading stops at the first line whose stripped text equals ``end_token``;
    that line is not included. Reading also stops at EOF (Ctrl-D / Ctrl-Z, or
    the end of piped input), so text without a terminator is kept instead of
    crashing.

    Returns:
        The collected lines joined with newlines, with leading and trailing
        whitespace stripped.
    """
    print(label)
    print(f"Finish lyrics by typing {end_token} on its own line.")
    lines: list[str] = []
    while True:
        try:
            line = input()
        except EOFError:
            # Treat EOF like the end token rather than dumping a traceback.
            break
        if line.strip() == end_token:
            break
        lines.append(line)
    return "\n".join(lines).strip()
def prompt_lyrics_optional() -> str:
    """Offer to collect custom lyrics; return an empty string if declined."""
    if prompt_yes_no("Add custom lyrics", True):
        return prompt_multiline("Paste lyrics (or just type END for none)")
    return ""
def send_request(url: str, token: str, payload: dict, timeout: float = 3600) -> dict:
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON object.

    Args:
        url: Endpoint URL.
        token: Bearer token sent in the ``Authorization`` header.
        payload: JSON-serializable request body.
        timeout: Socket timeout in seconds passed to ``urlopen``. Defaults to
            one hour because generation can be slow.

    Returns:
        The response body parsed as a JSON object (dict).

    Raises:
        ValueError: ``payload`` contains NaN/Infinity (not valid JSON).
        TypeError: ``payload`` is not JSON-serializable.
        RuntimeError: HTTP error status, network failure or timeout, a body
            that is not UTF-8 JSON, or JSON that is not an object.
    """
    # allow_nan=False: fail locally instead of sending non-standard JSON.
    data = json.dumps(payload, allow_nan=False).encode("utf-8")
    req = Request(
        url=url,
        data=data,
        method="POST",
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
    )
    try:
        with urlopen(req, timeout=timeout) as resp:
            raw = resp.read()
    except HTTPError as e:
        text = e.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HTTP {e.code}: {text}") from e
    except URLError as e:
        raise RuntimeError(f"Network error: {e}") from e
    except OSError as e:
        # Read timeouts (TimeoutError) and connection resets during the body
        # read are plain OSErrors, not URLError.
        raise RuntimeError(f"Network error: {e}") from e
    try:
        result = json.loads(raw.decode("utf-8"))
    except ValueError as e:  # UnicodeDecodeError and JSONDecodeError both subclass it
        raise RuntimeError(f"Response is not valid JSON: {raw[:200]!r}") from e
    if not isinstance(result, dict):
        raise RuntimeError(f"Expected a JSON object, got {type(result).__name__}")
    return result
def resolve_token(cli_token: str) -> str:
    """Pick the Hugging Face token from the first non-empty source.

    Precedence: the ``--token`` value, then the ``HF_TOKEN`` / ``hf_token``
    environment variables, then ``hf_token`` / ``HF_TOKEN`` in ``.env``.
    Returns ``""`` when none is set.
    """
    if cli_token:
        return cli_token
    for env_name in ("HF_TOKEN", "hf_token"):
        found = os.getenv(env_name)
        if found:
            return found
    for dotenv_name in ("hf_token", "HF_TOKEN"):
        found = read_dotenv_value(dotenv_name)
        if found:
            return found
    return ""
def _prompt_positive_int(label: str, default: int | None, allow_blank: bool = False) -> int | None:
    """Ask via prompt_int until the answer is >= 1 (or blank, when allowed)."""
    while True:
        value = prompt_int(label, default, allow_blank=allow_blank)
        if value is None or value > 0:
            return value
        print("Enter a positive integer.")


def _collect_request(args: argparse.Namespace) -> tuple[str, dict, str]:
    """Interactively gather the endpoint URL, request payload and output path.

    Returns:
        ``(url, payload, out_file)`` where ``payload`` is ``{"inputs": {...}}``.
    """
    url = prompt_text("Endpoint URL", args.url, required=True)
    music_prompt = prompt_text("Music prompt", args.prompt, required=True)
    bpm = _prompt_positive_int("BPM (blank for auto)", None, allow_blank=True)
    duration_sec = _prompt_positive_int("Duration seconds", 120)
    instrumental = prompt_yes_no("Instrumental (no vocals)", False)
    lyrics = "" if instrumental else prompt_lyrics_optional()
    # Quality-first defaults: use SFT + LM path configured on the endpoint.
    sample_rate = DEFAULT_SAMPLE_RATE
    steps = 50
    guidance_scale = 7.0
    seed = 42
    use_lm = True
    allow_fallback = False
    simple_prompt = False
    if args.advanced:
        print("\nAdvanced options:")
        sample_rate = _prompt_positive_int("Sample rate", DEFAULT_SAMPLE_RATE)
        steps = _prompt_positive_int("Steps", 50)
        guidance_scale = prompt_float("Guidance scale", 7.0)
        seed = prompt_int("Seed", 42)
        use_lm = prompt_yes_no("Use LM planning (higher quality, slower)", True)
        allow_fallback = prompt_yes_no("Allow fallback sine audio", False)
    default_out = args.out_file or f"music_{int(time.time())}.wav"
    out_file = prompt_text("Output file", default_out, required=True)
    inputs = {
        "prompt": music_prompt,
        "duration_sec": duration_sec,
        "sample_rate": sample_rate,
        "seed": seed,
        "guidance_scale": guidance_scale,
        "steps": steps,
        "use_lm": use_lm,
        "simple_prompt": simple_prompt,
        "instrumental": instrumental,
        "allow_fallback": allow_fallback,
    }
    # Optional fields are only sent when set.
    if bpm is not None:
        inputs["bpm"] = bpm
    if lyrics:
        inputs["lyrics"] = lyrics
    return url, {"inputs": inputs}, out_file


def _save_audio(out_file: str, audio_bytes: bytes) -> str | None:
    """Write ``audio_bytes`` to ``out_file``, creating missing parent dirs.

    If that fails, retry once in the current directory under a timestamped
    name, so a slow generation is not thrown away over a bad path.

    Returns:
        The path actually written, or ``None`` if both attempts failed.
    """
    try:
        target = Path(out_file)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(audio_bytes)
        return out_file
    except OSError as e:
        print(f"Could not write {out_file}: {e}")
    fallback = f"music_{int(time.time())}.wav"
    try:
        Path(fallback).write_bytes(audio_bytes)
    except OSError as e:
        print(f"Could not write fallback file {fallback}: {e}")
        return None
    return fallback


def main() -> int:
    """Run the interactive ACE-Step generation flow.

    Parses CLI flags and resolves the HF token. Asks for the generation
    settings, POSTs them to the endpoint, and writes the returned
    ``audio_base64_wav`` payload to a WAV file.

    Returns:
        Process exit code:
        - 0 on success.
        - 1 on a missing token, request failure, endpoint error, bad audio
          payload, or unwritable output.
        - 130 if the user aborts a prompt with Ctrl-C / EOF.
    """
    parser = argparse.ArgumentParser(description="Interactive ACE-Step endpoint generator")
    parser.add_argument("--url", default=os.getenv("HF_ENDPOINT_URL", DEFAULT_URL), help="Inference endpoint URL")
    parser.add_argument("--token", default="", help="HF token. If omitted, uses env/.env")
    parser.add_argument("--prompt", default="", help="Initial prompt")
    parser.add_argument("--out-file", default="", help="Output WAV file path")
    parser.add_argument(
        "--advanced",
        action="store_true",
        help="Ask advanced generation options (seed/guidance/steps/sample-rate/LM).",
    )
    args = parser.parse_args()
    print("=== ACE-Step Interactive Generation ===")
    token = resolve_token(args.token)
    if not token:
        print("No token found. Set HF_TOKEN or hf_token in .env, or pass --token.")
        return 1
    try:
        url, payload, out_file = _collect_request(args)
    except (KeyboardInterrupt, EOFError):
        print("\nAborted.")
        return 130
    print("\nSending request...")
    try:
        response = send_request(url, token, payload)
    except Exception as e:  # top-level boundary: report and exit non-zero
        print(f"Request failed: {e}")
        return 1
    # Guard: a JSON array/scalar body would make the .get() calls below crash.
    if not isinstance(response, dict):
        print(f"Unexpected response type: {type(response).__name__}")
        return 1
    print("Response summary:")
    print(json.dumps({
        "used_fallback": response.get("used_fallback"),
        "model_loaded": response.get("model_loaded"),
        "model_error": response.get("model_error"),
        "sample_rate": response.get("sample_rate"),
        "duration_sec": response.get("duration_sec"),
    }, indent=2))
    if response.get("error"):
        print(f"Endpoint error: {response['error']}")
        return 1
    audio_b64 = response.get("audio_base64_wav")
    if not audio_b64:
        print("No audio_base64_wav in response.")
        return 1
    try:
        audio_bytes = base64.b64decode(audio_b64)
    except (TypeError, ValueError) as e:  # binascii.Error subclasses ValueError
        print(f"Invalid audio_base64_wav payload: {e}")
        return 1
    saved = _save_audio(out_file, audio_bytes)
    if saved is None:
        return 1
    print(f"Saved audio: {saved}")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
|