#!/usr/bin/env python
"""
Send one audio file to an Audio Flamingo 3 endpoint and print/save the response.
"""
from __future__ import annotations
import argparse
import base64
import io
import json
import sys
from pathlib import Path
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
import soundfile as sf
# Resolve the directory two levels above this file's own directory and make
# sure it is first on sys.path, so the project-local imports that follow
# can be resolved when the script is run directly.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))
from af3_chatgpt_pipeline import DEFAULT_AF3_PROMPT, DEFAULT_AF3_PROMPT_THINK_LONG
from qwen_audio_captioning import load_audio_mono
from utils.env_config import get_env, load_project_env
def load_audio_b64(audio_path: str, target_sr: int = 16000) -> str:
    """Return the audio at *audio_path* as base64 text of an in-memory WAV file.

    Args:
        audio_path: Path of the audio file, forwarded to ``load_audio_mono``.
        target_sr: Sample rate forwarded to ``load_audio_mono`` as
            ``target_sr``.

    Returns:
        The WAV bytes written by ``soundfile`` encoded as base64 and decoded
        to a ``str``.
    """
    # load_audio_mono yields (samples, sample_rate); presumably the samples are
    # mono and already at target_sr -- confirm against that helper.
    samples, sample_rate = load_audio_mono(audio_path, target_sr=target_sr)

    # Serialize to WAV entirely in memory; nothing is written to disk.
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, samples, int(sample_rate), format="WAV")

    wav_bytes = wav_buffer.getvalue()
    encoded = base64.b64encode(wav_bytes)
    return encoded.decode("utf-8")
def send(url: str, token: str, payload: dict) -> dict:
    """POST *payload* as JSON to *url* and return the parsed JSON response.

    Args:
        url: Endpoint URL to post to.
        token: Bearer token; the Authorization header is omitted when falsy.
        payload: JSON-serializable request body.

    Returns:
        The response body decoded as UTF-8 and parsed with ``json.loads``.

    Raises:
        RuntimeError: On an HTTP error status (the message carries the status
            code, the response body and, for recognised failure texts, a
            troubleshooting hint) or on any other ``URLError``.
    """
    # Each rule: every listed substring must occur in the lower-cased error
    # body for the hint to be appended. Rules are evaluated in order and are
    # not mutually exclusive.
    hint_rules = (
        (
            ("endpoint is in error",),
            "\nHint: open the endpoint page and restart/redeploy. "
            "This is a remote runtime failure, not a local script issue.",
        ),
        (
            ("no custom pipeline found",),
            "\nHint: endpoint repo root must contain handler.py; "
            "ensure you deployed templates/hf-af3-caption-endpoint files.",
        ),
        (
            ("audioflamingo3", "does not recognize"),
            "\nHint: runtime transformers is too old. "
            "Use templates/hf-af3-caption-endpoint/handler.py bootstrap runtime "
            "(AF3_TRANSFORMERS_SPEC=transformers==5.1.0) and redeploy.",
        ),
        (
            ("failed to load af3 processor classes after runtime bootstrap",),
            "\nHint: endpoint startup could not install/load AF3 runtime deps. "
            "Check startup logs for pip/network/disk issues and keep task=custom.",
        ),
    )

    encoded_payload = json.dumps(payload).encode("utf-8")

    request_headers = {}
    if token:
        request_headers["Authorization"] = f"Bearer {token}"
    request_headers["Content-Type"] = "application/json"

    request = Request(
        url=url,
        data=encoded_payload,
        headers=request_headers,
        method="POST",
    )

    try:
        with urlopen(request, timeout=600) as response:
            raw_body = response.read()
            return json.loads(raw_body.decode("utf-8"))
    except HTTPError as http_err:
        # HTTPError must be handled before URLError because it is a subclass.
        detail = http_err.read().decode("utf-8", errors="replace")
        # Match against the server's text only, captured before hints are added.
        haystack = detail.lower()
        for needles, hint in hint_rules:
            if all(needle in haystack for needle in needles):
                detail += hint
        raise RuntimeError(f"HTTP {http_err.code}: {detail}") from http_err
    except URLError as url_err:
        raise RuntimeError(f"Network error: {url_err}") from url_err
def main() -> int:
    """Send one audio file to the AF3 caption endpoint and print/save the reply.

    Command-line options supply the endpoint URL and token (falling back to
    the values returned by ``get_env``), the audio path, the prompt and the
    generation settings. The audio is encoded with ``load_audio_b64``, posted
    with ``send``, and the JSON result is printed and optionally written to
    ``--save-json``.

    Returns:
        ``0`` on success.

    Raises:
        RuntimeError: If no endpoint URL is available, or if ``send`` reports
            an HTTP/network failure.
        FileNotFoundError: If ``--audio`` does not name an existing file.
    """
    # One sample rate is used both for loading the audio and for the
    # "sample_rate" field of the payload, so keep it in a single place.
    sample_rate = 16000
    # Generation defaults for the standard and the --think-long profiles.
    standard_max_new_tokens, standard_temperature = 1400, 0.1
    think_long_max_new_tokens, think_long_temperature = 3200, 0.2

    load_project_env()

    parser = argparse.ArgumentParser(description="Test AF3 caption endpoint")
    parser.add_argument(
        "--url",
        default=get_env("HF_AF3_ENDPOINT_URL", "hf_af3_endpoint_url"),
        required=False,
    )
    parser.add_argument(
        "--token",
        default=get_env("HF_TOKEN", "hf_token"),
        required=False,
    )
    parser.add_argument("--audio", required=True, help="Path to local audio file")
    # --prompt, --max-new-tokens and --temperature default to None so that
    # "not given" can be told apart from an explicit value that happens to
    # equal the standard default; --think-long must only replace the former.
    parser.add_argument("--prompt", default=None)
    parser.add_argument(
        "--mode",
        choices=["auto", "think", "single"],
        default="auto",
        help="Optional AF3 mode selector for NVIDIA-stack endpoints.",
    )
    parser.add_argument(
        "--think-long",
        action="store_true",
        help="Use long-form AF3 prompt + higher token budget defaults.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=None,
        help="Default: 1400 (3200 with --think-long).",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=None,
        help="Default: 0.1 (0.2 with --think-long).",
    )
    parser.add_argument("--save-json", default="", help="Optional output JSON path")
    args = parser.parse_args()

    if not args.url:
        raise RuntimeError("Missing endpoint URL. Pass --url or set HF_AF3_ENDPOINT_URL.")
    if not Path(args.audio).is_file():
        raise FileNotFoundError(f"Audio file not found: {args.audio}")

    audio_b64 = load_audio_b64(args.audio, target_sr=sample_rate)

    # Pick the fallback profile, then let explicitly supplied values win.
    if args.think_long:
        fallback_prompt = DEFAULT_AF3_PROMPT_THINK_LONG
        fallback_max_new_tokens = think_long_max_new_tokens
        fallback_temperature = think_long_temperature
    else:
        fallback_prompt = DEFAULT_AF3_PROMPT
        fallback_max_new_tokens = standard_max_new_tokens
        fallback_temperature = standard_temperature

    prompt = fallback_prompt if args.prompt is None else args.prompt
    max_new_tokens = (
        fallback_max_new_tokens if args.max_new_tokens is None else args.max_new_tokens
    )
    temperature = fallback_temperature if args.temperature is None else args.temperature

    payload = {
        "inputs": {
            "prompt": prompt,
            "audio_base64": audio_b64,
            "sample_rate": sample_rate,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }
    }
    # "auto" leaves think_mode unset so the endpoint applies its own default.
    if args.mode != "auto":
        payload["inputs"]["think_mode"] = args.mode == "think"

    result = send(args.url, args.token, payload)

    rendered = json.dumps(result, indent=2, ensure_ascii=False)
    try:
        print(rendered)
    except UnicodeEncodeError:
        # Fallback for Windows cp1252 terminals when model emits non-ASCII punctuation.
        print(json.dumps(result, indent=2, ensure_ascii=True))

    if args.save_json:
        out_path = Path(args.save_json)
        # Create missing parent directories so a long-running request is not
        # lost to a trivial "directory does not exist" failure at save time.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(rendered, encoding="utf-8")
        print(f"Saved: {args.save_json}")
    return 0
# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    sys.exit(main())
|