Spaces:
Running on Zero
Running on Zero
| #!/usr/bin/env python | |
| """ | |
| Send one audio file to an Audio Flamingo 3 endpoint and print/save the response. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import base64 | |
| import io | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from urllib.error import HTTPError, URLError | |
| from urllib.request import Request, urlopen | |
| import soundfile as sf | |
# Directory two levels above the one containing this file (presumably the
# repository root -- TODO confirm against the repo layout). It is placed at the
# front of sys.path so the project-local imports that follow resolve when this
# file is executed directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path[:0] = [str(PROJECT_ROOT)]
| from af3_chatgpt_pipeline import DEFAULT_AF3_PROMPT, DEFAULT_AF3_PROMPT_THINK_LONG | |
| from qwen_audio_captioning import load_audio_mono | |
| from utils.env_config import get_env, load_project_env | |
def load_audio_b64(audio_path: str, target_sr: int = 16000) -> str:
    """Return *audio_path* as a base64-encoded, in-memory mono WAV file.

    The audio is decoded via ``load_audio_mono`` (``target_sr`` is forwarded to it;
    presumably it resamples to that rate -- confirm in qwen_audio_captioning). The
    WAV header uses the sample rate ``load_audio_mono`` reports back.

    Args:
        audio_path: Path to a local audio file.
        target_sr: Sample rate requested from ``load_audio_mono``.

    Returns:
        The WAV bytes encoded as base64 and decoded to a UTF-8 ``str``.
    """
    samples, rate = load_audio_mono(audio_path, target_sr=target_sr)
    with io.BytesIO() as wav_buffer:
        sf.write(wav_buffer, samples, int(rate), format="WAV")
        wav_bytes = wav_buffer.getvalue()
    encoded = base64.b64encode(wav_bytes)
    return encoded.decode("utf-8")
# Troubleshooting hints appended to HTTP error bodies. Each entry is
# (needles, hint): the hint is added when every needle occurs in the
# lower-cased response body. Hints are appended in table order.
_HTTP_ERROR_HINTS = (
    (
        ("endpoint is in error",),
        "\nHint: open the endpoint page and restart/redeploy. "
        "This is a remote runtime failure, not a local script issue.",
    ),
    (
        ("no custom pipeline found",),
        "\nHint: endpoint repo root must contain handler.py; "
        "ensure you deployed templates/hf-af3-caption-endpoint files.",
    ),
    (
        ("audioflamingo3", "does not recognize"),
        "\nHint: runtime transformers is too old. "
        "Use templates/hf-af3-caption-endpoint/handler.py bootstrap runtime "
        "(AF3_TRANSFORMERS_SPEC=transformers==5.1.0) and redeploy.",
    ),
    (
        ("failed to load af3 processor classes after runtime bootstrap",),
        "\nHint: endpoint startup could not install/load AF3 runtime deps. "
        "Check startup logs for pip/network/disk issues and keep task=custom.",
    ),
)


def _error_hints(body: str) -> str:
    """Return the concatenated hints from ``_HTTP_ERROR_HINTS`` that match *body*."""
    lower = body.lower()
    return "".join(
        hint
        for needles, hint in _HTTP_ERROR_HINTS
        if all(needle in lower for needle in needles)
    )


def send(url: str, token: str, payload: dict, timeout: float = 600) -> dict:
    """POST *payload* as JSON to *url* and return the decoded JSON response.

    Args:
        url: Endpoint URL.
        token: Bearer token; when falsy, no Authorization header is sent.
        payload: JSON-serializable request body.
        timeout: Timeout in seconds passed to ``urlopen`` (default 600).

    Returns:
        The parsed JSON response body.

    Raises:
        RuntimeError: on an HTTP error status (the message carries the response
            body plus troubleshooting hints for known AF3 endpoint failures), on a
            network/transport failure, or when the response is not UTF-8 JSON.
    """
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    headers["Content-Type"] = "application/json"
    req = Request(
        url=url,
        method="POST",
        data=json.dumps(payload).encode("utf-8"),
        headers=headers,
    )
    try:
        with urlopen(req, timeout=timeout) as resp:
            raw = resp.read()
    except HTTPError as e:
        body = e.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HTTP {e.code}: {body}{_error_hints(body)}") from e
    except OSError as e:
        # URLError (DNS/connect failures), socket timeouts raised while reading
        # the response and connection resets are all OSError subclasses; only
        # URLError was handled before, so read timeouts escaped unwrapped.
        raise RuntimeError(f"Network error: {e}") from e

    try:
        return json.loads(raw.decode("utf-8"))
    except ValueError as e:  # UnicodeDecodeError and JSONDecodeError
        preview = raw[:500].decode("utf-8", errors="replace")
        raise RuntimeError(
            f"Endpoint returned a non-JSON response: {preview!r}"
        ) from e
# Sample rate the audio is resampled to locally AND declared to the endpoint;
# kept in one place so the encoded WAV request and the payload agree.
_AF3_SAMPLE_RATE = 16000

# Generation defaults used when the option is not given on the command line.
_DEFAULT_MAX_NEW_TOKENS = 1400
_DEFAULT_TEMPERATURE = 0.1
# Larger budget / slightly higher temperature applied with --think-long.
_THINK_LONG_MAX_NEW_TOKENS = 3200
_THINK_LONG_TEMPERATURE = 0.2


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser.

    ``--url`` and ``--token`` take their defaults from ``get_env``, so the project
    environment must already be loaded when this is called.
    """
    parser = argparse.ArgumentParser(description="Test AF3 caption endpoint")
    parser.add_argument(
        "--url",
        default=get_env("HF_AF3_ENDPOINT_URL", "hf_af3_endpoint_url"),
        required=False,
    )
    parser.add_argument(
        "--token",
        default=get_env("HF_TOKEN", "hf_token"),
        required=False,
    )
    parser.add_argument("--audio", required=True, help="Path to local audio file")
    # Generation options default to None so an explicit value can be told apart
    # from "not given"; the effective defaults depend on --think-long.
    parser.add_argument(
        "--prompt",
        default=None,
        help="Prompt text (default: the standard AF3 prompt, or the long-form "
        "prompt with --think-long).",
    )
    parser.add_argument(
        "--mode",
        choices=["auto", "think", "single"],
        default="auto",
        help="Optional AF3 mode selector for NVIDIA-stack endpoints.",
    )
    parser.add_argument(
        "--think-long",
        action="store_true",
        help="Use long-form AF3 prompt + higher token budget defaults.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=None,
        help=f"Generation token budget (default: {_DEFAULT_MAX_NEW_TOKENS}, "
        f"or {_THINK_LONG_MAX_NEW_TOKENS} with --think-long).",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=None,
        help=f"Sampling temperature (default: {_DEFAULT_TEMPERATURE}, "
        f"or {_THINK_LONG_TEMPERATURE} with --think-long).",
    )
    parser.add_argument("--save-json", default="", help="Optional output JSON path")
    return parser


def _resolve_generation_settings(args: argparse.Namespace) -> tuple[str, int, float]:
    """Return ``(prompt, max_new_tokens, temperature)`` with unset options filled in.

    Values passed explicitly on the command line always win, even when they equal
    the standard defaults; only options left unset pick up the --think-long
    defaults.
    """
    if args.think_long:
        default_prompt = DEFAULT_AF3_PROMPT_THINK_LONG
        default_tokens = _THINK_LONG_MAX_NEW_TOKENS
        default_temperature = _THINK_LONG_TEMPERATURE
    else:
        default_prompt = DEFAULT_AF3_PROMPT
        default_tokens = _DEFAULT_MAX_NEW_TOKENS
        default_temperature = _DEFAULT_TEMPERATURE
    prompt = args.prompt if args.prompt is not None else default_prompt
    max_new_tokens = args.max_new_tokens if args.max_new_tokens is not None else default_tokens
    temperature = args.temperature if args.temperature is not None else default_temperature
    return prompt, max_new_tokens, temperature


def _build_payload(
    audio_b64: str, prompt: str, max_new_tokens: int, temperature: float, mode: str
) -> dict:
    """Assemble the endpoint request body.

    ``think_mode`` is only included when *mode* is not ``"auto"``; it is True for
    ``"think"`` and False for ``"single"``.
    """
    inputs = {
        "prompt": prompt,
        "audio_base64": audio_b64,
        "sample_rate": _AF3_SAMPLE_RATE,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
    }
    if mode != "auto":
        inputs["think_mode"] = mode == "think"
    return {"inputs": inputs}


def _print_json(result) -> str:
    """Print *result* as indented JSON and return the unescaped (UTF-8) rendering."""
    rendered = json.dumps(result, indent=2, ensure_ascii=False)
    try:
        print(rendered)
    except UnicodeEncodeError:
        # Fallback for Windows cp1252 terminals when model emits non-ASCII punctuation.
        print(json.dumps(result, indent=2, ensure_ascii=True))
    return rendered


def main() -> int:
    """CLI entry point: send one local audio file to the AF3 endpoint and print the reply.

    Returns:
        Process exit code (0 on success).

    Raises:
        RuntimeError: if no endpoint URL is configured, or if ``send`` fails.
        FileNotFoundError: if ``--audio`` does not name an existing file.
    """
    load_project_env()  # must run before the parser reads env-backed defaults
    args = _build_parser().parse_args()
    if not args.url:
        raise RuntimeError("Missing endpoint URL. Pass --url or set HF_AF3_ENDPOINT_URL.")
    if not Path(args.audio).is_file():
        raise FileNotFoundError(f"Audio file not found: {args.audio}")

    prompt, max_new_tokens, temperature = _resolve_generation_settings(args)
    audio_b64 = load_audio_b64(args.audio, target_sr=_AF3_SAMPLE_RATE)
    payload = _build_payload(audio_b64, prompt, max_new_tokens, temperature, args.mode)

    result = send(args.url, args.token, payload)
    rendered = _print_json(result)
    if args.save_json:
        out_path = Path(args.save_json)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(rendered, encoding="utf-8")
        print(f"Saved: {args.save_json}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())