# Page header (from the hosting site): aTrapDeer -- "Bootstrap Audio Flamingo 3 custom endpoint repo"
# Revision 4c99bf4 (verified).
import base64
import importlib
import importlib.machinery
import importlib.util
import io
import os
import subprocess
import sys
import types
from typing import Any, Dict, List, Tuple
import numpy as np
import soundfile as sf
import torch
def _resolve_model_id(model_dir: str) -> str:
    """Choose where the model should be loaded from.

    Returns ``model_dir`` when it is an existing directory that contains a
    ``config.json`` file. Otherwise returns the value of the ``AF3_MODEL_ID``
    environment variable, or ``"nvidia/audio-flamingo-3-hf"`` when that
    variable is unset.
    """
    fallback_id = os.getenv("AF3_MODEL_ID", "nvidia/audio-flamingo-3-hf")
    if not model_dir or not os.path.isdir(model_dir):
        return fallback_id
    config_path = os.path.join(model_dir, "config.json")
    return model_dir if os.path.exists(config_path) else fallback_id
def _log(msg: str) -> None:
    """Print ``msg`` with the ``[AF3 handler]`` prefix and flush immediately."""
    line = f"[AF3 handler] {msg}"
    print(line, flush=True)
def _env_true(name: str, default: bool = False) -> bool:
    """Interpret the environment variable ``name`` as a boolean flag.

    Returns ``default`` when the variable is unset. Otherwise the value is
    stripped, lower-cased and compared against ``1``, ``true``, ``yes`` and
    ``on``; any other value (including the empty string) yields ``False``.
    """
    value = os.getenv(name)
    if value is not None:
        normalized = str(value).strip().lower()
        return normalized in {"1", "true", "yes", "on"}
    return default
def _install_torchvision_stub() -> None:
    """Register placeholder ``torchvision`` modules in ``sys.modules``.

    Does nothing when the ``AF3_STUB_TORCHVISION`` flag (default: enabled) is
    turned off. Otherwise any existing ``torchvision`` and
    ``torchvision.transforms`` entries in ``sys.modules`` are replaced by
    empty modules; the only attribute provided is
    ``torchvision.transforms.InterpolationMode``, a namespace of integer
    constants.
    """
    if not _env_true("AF3_STUB_TORCHVISION", True):
        return

    def make_stub(module_name: str) -> types.ModuleType:
        # Give each stub a spec with no loader, as the original modules had.
        module = types.ModuleType(module_name)
        module.__spec__ = importlib.machinery.ModuleSpec(name=module_name, loader=None)
        return module

    transforms_stub = make_stub("torchvision.transforms")
    # NOTE(review): these values look like PIL resampling codes -- confirm.
    transforms_stub.InterpolationMode = types.SimpleNamespace(
        NEAREST=0,
        BILINEAR=2,
        BICUBIC=3,
        BOX=4,
        HAMMING=5,
        LANCZOS=1,
    )

    tv_stub = make_stub("torchvision")
    tv_stub.transforms = transforms_stub

    sys.modules["torchvision"] = tv_stub
    sys.modules["torchvision.transforms"] = transforms_stub
_FIND_SPEC_PATCHED = False
def _patch_optional_backend_discovery() -> None:
global _FIND_SPEC_PATCHED
if _FIND_SPEC_PATCHED:
return
blocked = {"torchvision", "librosa"}
original_find_spec = importlib.util.find_spec
def wrapped_find_spec(name: str, package: str | None = None):
root = name.split(".", 1)[0]
if root in blocked:
return None
return original_find_spec(name, package)
importlib.util.find_spec = wrapped_find_spec # type: ignore[assignment]
_FIND_SPEC_PATCHED = True
def _clear_python_modules(prefixes: Tuple[str, ...]) -> None:
    """Drop cached modules belonging to the given packages from ``sys.modules``.

    A module is removed when its name equals one of ``prefixes`` or starts
    with ``"<prefix>."``; names that merely share leading characters (for
    example ``foobar`` for prefix ``foo``) are left alone.
    """

    def belongs_to_target(module_name: str) -> bool:
        for prefix in prefixes:
            if module_name == prefix or module_name.startswith(prefix + "."):
                return True
        return False

    # Snapshot the keys first: sys.modules is mutated while we walk them.
    loaded_names = list(sys.modules.keys())
    for module_name in loaded_names:
        if belongs_to_target(module_name):
            sys.modules.pop(module_name, None)
def _patch_torch_compat() -> None:
    """Ensure ``TransformGetItemToIndex`` exists in torch's dynamo wrapper module.

    If ``torch._dynamo._trace_wrapped_higher_order_op`` cannot be imported the
    function silently returns. If the module imports but has no
    ``TransformGetItemToIndex`` attribute, an empty placeholder class with
    that name is attached to it.
    """
    try:
        import torch._dynamo._trace_wrapped_higher_order_op as dyn_wrap
    except Exception:
        # Best effort: nothing to patch if the module is not importable.
        return

    if not hasattr(dyn_wrap, "TransformGetItemToIndex"):

        class TransformGetItemToIndex:  # pragma: no cover - runtime compatibility shim
            pass

        dyn_wrap.TransformGetItemToIndex = TransformGetItemToIndex
def _af3_classes_available() -> tuple[bool, str]:
    """Check whether the Audio Flamingo 3 classes can be imported from transformers.

    Returns:
        ``(True, "")`` when both ``AudioFlamingo3ForConditionalGeneration``
        and ``AudioFlamingo3Processor`` import successfully; otherwise
        ``(False, "<ExceptionType>: <message>")`` describing the failure.
    """
    try:
        from transformers import (  # noqa: F401
            AudioFlamingo3ForConditionalGeneration,
            AudioFlamingo3Processor,
        )
    except Exception as exc:
        failure = f"{type(exc).__name__}: {exc}"
        return False, failure
    else:
        return True, ""
def _bootstrap_runtime_transformers(target_dir: str) -> None:
    """Pip-install the AF3 runtime dependencies into ``target_dir``.

    The transformers requirement comes from ``AF3_TRANSFORMERS_SPEC``
    (default ``transformers==5.1.0``); the remaining requirements are fixed.
    pip runs through the current interpreter with ``--upgrade``,
    ``--no-cache-dir`` and ``--target``.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    transformers_spec = os.getenv("AF3_TRANSFORMERS_SPEC", "transformers==5.1.0")
    requirements = (
        transformers_spec,
        "numpy<2",
        "accelerate>=1.1.0",
        "sentencepiece",
        "safetensors",
        "soxr",
    )
    pip_command = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--upgrade",
        "--no-cache-dir",
        "--target",
        target_dir,
    ]
    pip_command.extend(requirements)
    _log("Installing runtime deps for AF3 (first boot can take a few minutes).")
    subprocess.check_call(pip_command)
def _ensure_af3_transformers():
    """Return a ``transformers`` module that exposes the AF3 classes.

    The compatibility patches are applied and the already-installed
    transformers is tried first. If it lacks the AF3 classes and
    ``AF3_BOOTSTRAP_RUNTIME`` (default: enabled) allows it, the runtime
    dependencies are pip-installed into ``AF3_RUNTIME_DIR`` (default
    ``/tmp/af3_runtime``), that directory is put at the front of ``sys.path``,
    the cached transformers-related modules are dropped and transformers is
    imported again.

    Raises:
        RuntimeError: if the AF3 classes are unavailable and bootstrapping is
            disabled, or if they are still unavailable after bootstrapping.
    """

    def apply_compat_patches() -> None:
        # Order kept as in the original sequence of calls.
        _patch_optional_backend_discovery()
        _install_torchvision_stub()
        _patch_torch_compat()

    apply_compat_patches()
    import transformers

    bundled_ok, import_error = _af3_classes_available()
    if bundled_ok:
        _log(f"Using bundled transformers={transformers.__version__}")
        return transformers

    if not _env_true("AF3_BOOTSTRAP_RUNTIME", True):
        raise RuntimeError(
            "AF3 classes are unavailable in bundled transformers "
            f"({transformers.__version__}) and AF3_BOOTSTRAP_RUNTIME is disabled. "
            f"Last import error: {import_error}"
        )

    runtime_dir = os.getenv("AF3_RUNTIME_DIR", "/tmp/af3_runtime")
    os.makedirs(runtime_dir, exist_ok=True)
    _bootstrap_runtime_transformers(runtime_dir)
    if runtime_dir not in sys.path:
        sys.path.insert(0, runtime_dir)

    # Forget the previously imported packages so the next import resolves
    # against the updated sys.path.
    _clear_python_modules(("transformers", "tokenizers", "huggingface_hub", "safetensors"))
    apply_compat_patches()
    importlib.invalidate_caches()
    reloaded = importlib.import_module("transformers")

    reloaded_ok, import_error = _af3_classes_available()
    if not reloaded_ok:
        raise RuntimeError(
            "Failed to load AF3 processor classes after runtime bootstrap. "
            f"transformers={getattr(reloaded, '__version__', 'unknown')} "
            f"error={import_error}"
        )

    _log(f"Bootstrapped transformers={reloaded.__version__}")
    return reloaded
def _resample_audio_mono(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample a one-dimensional signal with linear interpolation.

    Output sample ``k`` is taken at source position ``k * src_sr / dst_sr``,
    so the time scale is preserved exactly. Positions past the last source
    sample take the last sample's value. No low-pass filtering is applied
    before downsampling.

    Args:
        audio: Mono samples; indexed along its first (only) axis.
        src_sr: Sample rate of ``audio`` in Hz.
        dst_sr: Desired sample rate in Hz.

    Returns:
        A float32 array of ``round(len(audio) * dst_sr / src_sr)`` samples
        (at least one). When the rates are equal the input is returned as
        float32 without resampling; empty input yields an empty array.

    Raises:
        ValueError: if the rates differ and either one is not positive.
    """
    if src_sr == dst_sr:
        return audio.astype(np.float32, copy=False)
    if audio.size == 0:
        return np.zeros((0,), dtype=np.float32)
    if src_sr <= 0 or dst_sr <= 0:
        raise ValueError(
            f"sample rates must be positive, got src_sr={src_sr} dst_sr={dst_sr}"
        )

    src_len = audio.shape[0]
    dst_len = max(int(round(src_len * float(dst_sr) / float(src_sr))), 1)

    src_idx = np.arange(src_len, dtype=np.float64)
    # Step through the source at the true rate ratio. Stretching the grid so
    # the endpoints coincide (linspace(0, n - 1, dst_len)) would slightly
    # change the playback speed.
    step = float(src_sr) / float(dst_sr)
    dst_idx = np.arange(dst_len, dtype=np.float64) * step

    # np.interp holds the final sample for positions beyond the last index.
    out = np.interp(dst_idx, src_idx, audio.astype(np.float64, copy=False))
    return out.astype(np.float32, copy=False)
def _decode_audio_from_b64(audio_b64: str) -> tuple[np.ndarray, int]:
    """Decode a base64-encoded audio file into mono float32 samples at 16 kHz.

    The decoded bytes are read with ``soundfile``. Two-dimensional data is
    averaged over its second axis, anything that is still not one-dimensional
    is flattened, and the signal is resampled when the file's rate is not
    16000.

    Returns:
        ``(samples, 16000)`` where ``samples`` is a 1-D float32 array.
    """
    target_sr = 16000

    buffer = io.BytesIO(base64.b64decode(audio_b64))
    samples, file_sr = sf.read(buffer, dtype="float32", always_2d=False)

    if samples.ndim == 2:
        # Average over the second axis to collapse to a single channel.
        samples = np.mean(samples, axis=1)
    if samples.ndim != 1:
        samples = np.asarray(samples).reshape(-1)

    source_sr = int(file_sr)
    if source_sr != target_sr:
        samples = _resample_audio_mono(samples, source_sr, target_sr)

    return samples.astype(np.float32, copy=False), target_sr
class EndpointHandler:
    """
    Hugging Face Dedicated Endpoint custom handler.

    Request:
        {
          "inputs": {
            "prompt": "...",
            "audio_base64": "...",
            "max_new_tokens": 1200,
            "temperature": 0.1
          }
        }

    Response:
        {"generated_text": "..."} on success, {"error": "..."} on failure.
    """

    # Defaults applied when a request omits (or nulls) the matching field.
    _DEFAULT_PROMPT = "Analyze this full song and summarize arrangement changes."
    _DEFAULT_MAX_NEW_TOKENS = 1200
    _DEFAULT_TEMPERATURE = 0.1

    def __init__(self, model_dir: str = ""):
        """Load the processor and model and move the model to the best device.

        Args:
            model_dir: Optional local checkpoint directory; resolved to a
                model id by ``_resolve_model_id``.
        """
        self.model_id = _resolve_model_id(model_dir)
        self.transformers = _ensure_af3_transformers()
        # Import only now: _ensure_af3_transformers() may have swapped in a
        # freshly installed transformers package.
        from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor

        cuda_available = torch.cuda.is_available()
        _log(
            f"torch={torch.__version__} cuda={cuda_available} "
            f"transformers={self.transformers.__version__} model_id={self.model_id}"
        )
        # Half precision on GPU, full precision on CPU.
        dtype = torch.float16 if cuda_available else torch.float32
        self.processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = AudioFlamingo3ForConditionalGeneration.from_pretrained(
            self.model_id,
            torch_dtype=dtype,
            trust_remote_code=True,
        )
        self.device = "cuda" if cuda_available else "cpu"
        self.model.to(self.device)

    def _build_inputs(self, audio: np.ndarray, sample_rate: int, prompt: str) -> Dict[str, Any]:
        """Turn one audio clip plus a text prompt into processor outputs.

        Builds a single-turn user conversation (audio followed by text) and
        runs it through the processor's chat template. The sampling rate is
        passed via ``audio_kwargs`` first; if that call raises, the failure is
        logged and the template is applied again without ``audio_kwargs``.
        """
        conversation: List[Dict[str, Any]] = [
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": audio},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        template_kwargs: Dict[str, Any] = {
            "tokenize": True,
            "add_generation_prompt": True,
            "return_dict": True,
            "return_tensors": "pt",
        }
        try:
            return self.processor.apply_chat_template(
                conversation,
                audio_kwargs={"sampling_rate": int(sample_rate)},
                **template_kwargs,
            )
        except Exception as exc:
            # Best-effort fallback, but leave a trace instead of hiding the cause.
            _log(
                "apply_chat_template with audio_kwargs failed "
                f"({type(exc).__name__}: {exc}); retrying without audio_kwargs"
            )
            return self.processor.apply_chat_template(conversation, **template_kwargs)

    @classmethod
    def _parse_generation_options(cls, payload: Dict[str, Any]) -> tuple:
        """Read ``max_new_tokens`` and ``temperature`` from ``payload``.

        Missing or ``None`` values fall back to the class defaults.

        Raises:
            TypeError, ValueError, OverflowError: if a value cannot be
                converted to ``int`` / ``float``.
        """
        raw_tokens = payload.get("max_new_tokens")
        raw_temperature = payload.get("temperature")
        max_new_tokens = cls._DEFAULT_MAX_NEW_TOKENS if raw_tokens is None else int(raw_tokens)
        temperature = cls._DEFAULT_TEMPERATURE if raw_temperature is None else float(raw_temperature)
        return max_new_tokens, temperature

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one inference request.

        Args:
            data: Request body; the options are read from ``data["inputs"]``
                when present, otherwise from ``data`` itself.

        Returns:
            ``{"generated_text": ...}`` on success, or ``{"error": ...}`` when
            the request is malformed or inference raises.
        """
        payload = data.get("inputs", data) if isinstance(data, dict) else {}
        if not isinstance(payload, dict):
            # e.g. {"inputs": "plain string"}: report it rather than crash on .get().
            return {"error": "inputs must be an object containing audio_base64"}

        audio_b64 = payload.get("audio_base64")
        if not audio_b64:
            return {"error": "audio_base64 is required"}

        # A null or blank prompt falls back to the default instead of being
        # sent to the model as "None" or an empty string.
        raw_prompt = payload.get("prompt")
        prompt = "" if raw_prompt is None else str(raw_prompt).strip()
        if not prompt:
            prompt = self._DEFAULT_PROMPT

        try:
            max_new_tokens, temperature = self._parse_generation_options(payload)
        except (TypeError, ValueError, OverflowError) as exc:
            return {"error": f"invalid generation parameter: {exc}"}

        try:
            audio, sample_rate = _decode_audio_from_b64(audio_b64)
            inputs = self._build_inputs(audio, sample_rate, prompt)

            # Match every tensor to the model's device; floating-point tensors
            # are also cast to the model's dtype.
            reference_param = next(self.model.parameters())
            device = reference_param.device
            model_dtype = reference_param.dtype
            for key, value in list(inputs.items()):
                if not hasattr(value, "to"):
                    continue
                if hasattr(value, "dtype") and torch.is_floating_point(value):
                    inputs[key] = value.to(device=device, dtype=model_dtype)
                else:
                    inputs[key] = value.to(device)

            # temperature <= 0 selects greedy decoding.
            do_sample = bool(temperature > 0)
            gen_kwargs: Dict[str, Any] = {
                "max_new_tokens": max_new_tokens,
                "do_sample": do_sample,
            }
            if do_sample:
                gen_kwargs["temperature"] = max(temperature, 1e-5)

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **gen_kwargs)

            # Drop the prompt tokens so only the newly generated part is decoded.
            start = int(inputs["input_ids"].shape[1])
            text = self.processor.batch_decode(outputs[:, start:], skip_special_tokens=True)[0].strip()
            if not text:
                # Nothing left after slicing: decode the whole sequence instead.
                text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
            return {"generated_text": text}
        except Exception as exc:
            return {"error": str(exc)}