aTrapDeer committed on
Commit
48e3736
·
verified ·
1 Parent(s): eb5e699

Bootstrap Audio Flamingo 3 custom endpoint repo

Browse files
Files changed (3) hide show
  1. README.md +12 -10
  2. handler.py +171 -73
  3. requirements.txt +2 -12
README.md CHANGED
@@ -36,21 +36,23 @@ Then deploy a Dedicated Endpoint from that model repo.
36
  Important: make sure your endpoint repo contains top-level:
37
  - `handler.py`
38
  - `requirements.txt`
39
- - `sitecustomize.py`
40
- - `setup.py`
41
 
42
- from this folder.
43
- If logs say `No custom pipeline found at /repository/handler.py`, your files were not copied to repo root.
44
- Use Endpoint task `custom` so the runtime loads `handler.py` instead of the default Transformers pipeline.
45
 
46
- ## Endpoint env var
47
 
 
48
  - `AF3_MODEL_ID=nvidia/audio-flamingo-3-hf`
49
- - `PYTHONPATH=/repository` (ensures `sitecustomize.py` compatibility patch is loaded)
 
 
 
 
 
50
 
51
  ## Notes
52
 
53
  - Audio Flamingo 3 is large; use a GPU endpoint.
54
- - This handler returns raw prose analysis. Use the local AF3+ChatGPT pipeline to normalize to LoRA sidecar JSON.
55
- - If logs show `cannot import name 'is_tf_available' from transformers.file_utils`,
56
- ensure `sitecustomize.py` is present in repo root and endpoint env includes `PYTHONPATH=/repository`.
 
36
  Important: make sure your endpoint repo contains top-level:
37
  - `handler.py`
38
  - `requirements.txt`
39
+ - `README.md`
 
40
 
41
+ Use endpoint task `custom` so the runtime loads `handler.py` instead of a default Transformers pipeline.
 
 
42
 
43
+ ## Endpoint env vars
44
 
45
+ Required:
46
  - `AF3_MODEL_ID=nvidia/audio-flamingo-3-hf`
47
+
48
+ Optional runtime bootstrap (defaults shown):
49
+ - `AF3_BOOTSTRAP_RUNTIME=1`
50
+ - `AF3_TRANSFORMERS_SPEC=transformers==5.1.0`
51
+ - `AF3_RUNTIME_DIR=/tmp/af3_runtime`
52
+ - `AF3_STUB_TORCHVISION=1`
53
 
54
  ## Notes
55
 
56
  - Audio Flamingo 3 is large; use a GPU endpoint.
57
+ - First boot can take longer because the handler installs AF3-compatible runtime dependencies.
58
+ - This handler returns raw prose analysis. Use the local AF3+ChatGPT pipeline to normalize to LoRA sidecar JSON.
 
handler.py CHANGED
@@ -1,11 +1,15 @@
1
  import base64
 
 
2
  import os
3
- import tempfile
4
- from typing import Any, Dict, List
 
 
5
 
 
 
6
  import torch
7
- import transformers
8
- from transformers import AutoProcessor
9
 
10
 
11
  def _resolve_model_id(model_dir: str) -> str:
@@ -17,6 +21,134 @@ def _resolve_model_id(model_dir: str) -> str:
17
  return default_id
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  class EndpointHandler:
21
  """
22
  Hugging Face Dedicated Endpoint custom handler.
@@ -37,30 +169,17 @@ class EndpointHandler:
37
 
38
  def __init__(self, model_dir: str = ""):
39
  self.model_id = _resolve_model_id(model_dir)
40
- print(
41
- f"[AF3 handler] transformers={transformers.__version__} "
42
- f"AudioFlamingo3Processor={hasattr(transformers, 'AudioFlamingo3Processor')} "
43
- f"AudioFlamingo3ForConditionalGeneration={hasattr(transformers, 'AudioFlamingo3ForConditionalGeneration')}",
44
- flush=True,
 
45
  )
46
- try:
47
- from transformers import AudioFlamingo3ForConditionalGeneration
48
- model_cls = AudioFlamingo3ForConditionalGeneration
49
- except Exception:
50
- from transformers import AutoModelForImageTextToText
51
- model_cls = AutoModelForImageTextToText
52
 
53
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
54
- try:
55
- self.processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
56
- except Exception as exc:
57
- raise RuntimeError(
58
- "Failed to load AF3 processor. "
59
- f"transformers={transformers.__version__} "
60
- f"AudioFlamingo3Processor={hasattr(transformers, 'AudioFlamingo3Processor')} "
61
- f"model_id={self.model_id} error={exc}"
62
- ) from exc
63
- self.model = model_cls.from_pretrained(
64
  self.model_id,
65
  torch_dtype=dtype,
66
  trust_remote_code=True,
@@ -68,41 +187,31 @@ class EndpointHandler:
68
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
69
  self.model.to(self.device)
70
 
71
- def _build_inputs(self, audio_path: str, prompt: str) -> Dict[str, Any]:
72
- conversation_variants: List[List[Dict[str, Any]]] = [
73
- [
74
- {
75
- "role": "user",
76
- "content": [
77
- {"type": "audio", "path": audio_path},
78
- {"type": "text", "text": prompt},
79
- ],
80
- }
81
- ],
82
- [
83
- {
84
- "role": "user",
85
- "content": [
86
- {"type": "audio", "audio_url": audio_path},
87
- {"type": "text", "text": prompt},
88
- ],
89
- }
90
- ],
91
  ]
92
-
93
- last_exc: Exception | None = None
94
- for convo in conversation_variants:
95
- try:
96
- return self.processor.apply_chat_template(
97
- convo,
98
- tokenize=True,
99
- add_generation_prompt=True,
100
- return_dict=True,
101
- )
102
- except Exception as exc:
103
- last_exc = exc
104
- continue
105
- raise RuntimeError(f"Failed to build AF3 inputs from chat template: {last_exc}")
 
106
 
107
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
108
  payload = data.get("inputs", data) if isinstance(data, dict) else {}
@@ -114,14 +223,9 @@ class EndpointHandler:
114
  max_new_tokens = int(payload.get("max_new_tokens", 1200))
115
  temperature = float(payload.get("temperature", 0.1))
116
 
117
- tmp_path = ""
118
  try:
119
- raw = base64.b64decode(audio_b64)
120
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
121
- tmp.write(raw)
122
- tmp_path = tmp.name
123
-
124
- inputs = self._build_inputs(tmp_path, prompt)
125
  device = next(self.model.parameters()).device
126
  for key, value in list(inputs.items()):
127
  if hasattr(value, "to"):
@@ -144,10 +248,4 @@ class EndpointHandler:
144
  text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
145
  return {"generated_text": text}
146
  except Exception as exc:
147
- return {"error": str(exc)}
148
- finally:
149
- if tmp_path:
150
- try:
151
- os.unlink(tmp_path)
152
- except Exception:
153
- pass
 
1
  import base64
2
+ import importlib
3
+ import io
4
  import os
5
+ import subprocess
6
+ import sys
7
+ import types
8
+ from typing import Any, Dict, List, Tuple
9
 
10
+ import numpy as np
11
+ import soundfile as sf
12
  import torch
 
 
13
 
14
 
15
  def _resolve_model_id(model_dir: str) -> str:
 
21
  return default_id
22
 
23
 
24
def _log(msg: str) -> None:
    """Emit a tagged diagnostic line to stdout for the endpoint logs."""
    # flush=True so messages show up immediately in the container log stream.
    print(f"[AF3 handler] {msg}", flush=True)
26
+
27
+
28
def _env_true(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset; otherwise accepts the
    usual truthy spellings ("1", "true", "yes", "on", case-insensitive).
    """
    value = os.getenv(name)
    if value is None:
        return default
    truthy = {"1", "true", "yes", "on"}
    return str(value).strip().lower() in truthy
33
+
34
+
35
def _install_torchvision_stub() -> None:
    """Register a minimal fake ``torchvision`` in ``sys.modules``.

    transformers' image-processing imports only need
    ``torchvision.transforms.InterpolationMode``; stubbing it avoids
    installing the real (heavy) torchvision package. Controlled by the
    AF3_STUB_TORCHVISION env flag (default on); a no-op if torchvision
    is already importable/imported.
    """
    if not _env_true("AF3_STUB_TORCHVISION", True):
        return
    if "torchvision" in sys.modules:
        # A real (or previously installed stub) torchvision takes precedence.
        return
    # Constant values mirror PIL.Image resampling filter codes.
    modes = types.SimpleNamespace(
        NEAREST=0,
        LANCZOS=1,
        BILINEAR=2,
        BICUBIC=3,
        BOX=4,
        HAMMING=5,
    )
    transforms_mod = types.ModuleType("torchvision.transforms")
    transforms_mod.InterpolationMode = modes
    root_mod = types.ModuleType("torchvision")
    root_mod.transforms = transforms_mod
    sys.modules["torchvision"] = root_mod
    sys.modules["torchvision.transforms"] = transforms_mod
54
+
55
+
56
def _clear_python_modules(prefixes: Tuple[str, ...]) -> None:
    """Evict cached modules whose name equals, or is nested under, a prefix.

    Used so a freshly pip-installed package version is imported instead of
    the already-loaded one. ``fakepkg`` matches ``fakepkg`` and
    ``fakepkg.sub`` but not ``fakepkgother``.
    """
    def _matches(mod_name: str) -> bool:
        return any(
            mod_name == prefix or mod_name.startswith(prefix + ".")
            for prefix in prefixes
        )

    stale = [name for name in list(sys.modules) if _matches(name)]
    for name in stale:
        sys.modules.pop(name, None)
60
+
61
+
62
def _af3_classes_available() -> tuple[bool, str]:
    """Probe whether the current transformers build ships the AF3 classes.

    Returns ``(True, "")`` on success, otherwise ``(False, "<ExcType>: msg")``
    describing the import failure. Never raises.
    """
    try:
        from transformers import (  # noqa: F401
            AudioFlamingo3ForConditionalGeneration,
            AudioFlamingo3Processor,
        )
    except Exception as exc:
        return False, f"{type(exc).__name__}: {exc}"
    return True, ""
70
+
71
+
72
def _bootstrap_runtime_transformers(target_dir: str) -> None:
    """pip-install an AF3-compatible dependency set into *target_dir*.

    The transformers pin is overridable via AF3_TRANSFORMERS_SPEC.
    Raises ``subprocess.CalledProcessError`` if pip fails, surfacing the
    problem at endpoint boot rather than at first request.
    """
    spec = os.getenv("AF3_TRANSFORMERS_SPEC", "transformers==5.1.0")
    extra_packages = (
        "numpy<2",
        "accelerate>=1.1.0",
        "sentencepiece",
        "safetensors",
        "soxr",
    )
    command = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--upgrade",
        "--no-cache-dir",
        "--target",
        target_dir,
    ]
    command.append(spec)
    command.extend(extra_packages)
    _log("Installing runtime deps for AF3 (first boot can take a few minutes).")
    subprocess.check_call(command)
84
+
85
+
86
def _ensure_af3_transformers():
    """Return a transformers module exposing the AF3 classes.

    First tries the transformers bundled in the image; if the AF3 classes
    are missing and AF3_BOOTSTRAP_RUNTIME is enabled (default), pip-installs
    a compatible runtime into AF3_RUNTIME_DIR, purges the stale module
    cache, and re-imports. Raises ``RuntimeError`` when no AF3-capable
    transformers can be obtained.

    NOTE: the statement order below is load-bearing — stub torchvision
    before any transformers import, purge caches before re-importing.
    """
    # Stub torchvision first so importing transformers does not trip over
    # the missing real package.
    _install_torchvision_stub()

    import transformers

    available, err = _af3_classes_available()
    if available:
        _log(f"Using bundled transformers={transformers.__version__}")
        return transformers

    if not _env_true("AF3_BOOTSTRAP_RUNTIME", True):
        raise RuntimeError(
            "AF3 classes are unavailable in bundled transformers "
            f"({transformers.__version__}) and AF3_BOOTSTRAP_RUNTIME is disabled. "
            f"Last import error: {err}"
        )

    # Install a compatible runtime into a writable dir and put it first on
    # sys.path so it shadows the bundled packages.
    target_dir = os.getenv("AF3_RUNTIME_DIR", "/tmp/af3_runtime")
    os.makedirs(target_dir, exist_ok=True)
    _bootstrap_runtime_transformers(target_dir)
    if target_dir not in sys.path:
        sys.path.insert(0, target_dir)

    # Purge cached modules so the freshly installed versions are imported;
    # re-assert the torchvision stub (no-op if still registered).
    _clear_python_modules(("transformers", "tokenizers", "huggingface_hub", "safetensors"))
    _install_torchvision_stub()
    importlib.invalidate_caches()
    transformers = importlib.import_module("transformers")

    available, err = _af3_classes_available()
    if not available:
        raise RuntimeError(
            "Failed to load AF3 processor classes after runtime bootstrap. "
            f"transformers={getattr(transformers, '__version__', 'unknown')} "
            f"error={err}"
        )
    _log(f"Bootstrapped transformers={transformers.__version__}")
    return transformers
123
+
124
+
125
def _resample_audio_mono(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Linearly resample a 1-D mono signal from *src_sr* to *dst_sr* Hz.

    Returns float32 samples. Same-rate input is passed through (dtype
    normalized only); empty input yields an empty float32 array.

    NOTE(review): plain linear interpolation — no anti-aliasing filter when
    downsampling; acceptable for this handler's prompt-analysis use.
    """
    if src_sr == dst_sr:
        # Fast path: only normalize the dtype, avoid a copy when possible.
        return audio.astype(np.float32, copy=False)
    if audio.size == 0:
        return np.zeros((0,), dtype=np.float32)
    n_src = audio.shape[0]
    # Output length scales with the rate ratio; never shorter than 1 sample.
    n_dst = max(int(round(n_src * float(dst_sr) / float(src_sr))), 1)
    source_positions = np.arange(n_src, dtype=np.float64)
    target_positions = np.linspace(0.0, float(max(n_src - 1, 0)), n_dst, dtype=np.float64)
    resampled = np.interp(target_positions, source_positions, audio.astype(np.float64, copy=False))
    return resampled.astype(np.float32, copy=False)
136
+
137
+
138
def _decode_audio_from_b64(audio_b64: str) -> tuple[np.ndarray, int]:
    """Decode base64-encoded audio bytes to (mono float32 samples, 16000).

    The payload is parsed by soundfile (so any format libsndfile supports,
    not only WAV), downmixed to mono by channel averaging, and resampled
    to the 16 kHz rate the AF3 processor expects.
    """
    decoded = base64.b64decode(audio_b64)
    samples, rate = sf.read(io.BytesIO(decoded), dtype="float32", always_2d=False)
    if samples.ndim == 2:
        # Multi-channel: average channels down to mono.
        samples = np.mean(samples, axis=1)
    if samples.ndim != 1:
        # Defensive flatten for any unexpected shape from sf.read.
        samples = np.asarray(samples).reshape(-1)
    target_sr = 16000
    if int(rate) != target_sr:
        samples = _resample_audio_mono(samples, int(rate), target_sr)
        rate = target_sr
    return samples.astype(np.float32, copy=False), int(rate)
150
+
151
+
152
  class EndpointHandler:
153
  """
154
  Hugging Face Dedicated Endpoint custom handler.
 
169
 
170
  def __init__(self, model_dir: str = ""):
171
  self.model_id = _resolve_model_id(model_dir)
172
+ self.transformers = _ensure_af3_transformers()
173
+ from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
174
+
175
+ _log(
176
+ f"torch={torch.__version__} cuda={torch.cuda.is_available()} "
177
+ f"transformers={self.transformers.__version__} model_id={self.model_id}"
178
  )
 
 
 
 
 
 
179
 
180
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
181
+ self.processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
182
+ self.model = AudioFlamingo3ForConditionalGeneration.from_pretrained(
 
 
 
 
 
 
 
 
183
  self.model_id,
184
  torch_dtype=dtype,
185
  trust_remote_code=True,
 
187
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
188
  self.model.to(self.device)
189
 
190
+ def _build_inputs(self, audio: np.ndarray, sample_rate: int, prompt: str) -> Dict[str, Any]:
191
+ conversation: List[Dict[str, Any]] = [
192
+ {
193
+ "role": "user",
194
+ "content": [
195
+ {"type": "audio", "audio": audio},
196
+ {"type": "text", "text": prompt},
197
+ ],
198
+ }
 
 
 
 
 
 
 
 
 
 
 
199
  ]
200
+ try:
201
+ return self.processor.apply_chat_template(
202
+ conversation,
203
+ tokenize=True,
204
+ add_generation_prompt=True,
205
+ return_dict=True,
206
+ audio_kwargs={"sampling_rate": int(sample_rate)},
207
+ )
208
+ except Exception:
209
+ return self.processor.apply_chat_template(
210
+ conversation,
211
+ tokenize=True,
212
+ add_generation_prompt=True,
213
+ return_dict=True,
214
+ )
215
 
216
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
217
  payload = data.get("inputs", data) if isinstance(data, dict) else {}
 
223
  max_new_tokens = int(payload.get("max_new_tokens", 1200))
224
  temperature = float(payload.get("temperature", 0.1))
225
 
 
226
  try:
227
+ audio, sample_rate = _decode_audio_from_b64(audio_b64)
228
+ inputs = self._build_inputs(audio, sample_rate, prompt)
 
 
 
 
229
  device = next(self.model.parameters()).device
230
  for key, value in list(inputs.items()):
231
  if hasattr(value, "to"):
 
248
  text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
249
  return {"generated_text": text}
250
  except Exception as exc:
251
+ return {"error": str(exc)}
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,12 +1,2 @@
1
- -e .
2
- torch
3
- torchaudio
4
- soundfile
5
- numpy
6
- transformers==5.0.0rc1
7
- huggingface_hub>=1.0.0
8
- accelerate>=1.0.0
9
- diffusers>=0.35.0
10
- peft>=0.17.0
11
- sentencepiece
12
- safetensors
 
1
+ numpy<2
2
+ soundfile