Spaces:
Running
on
Zero
Running
on
Zero
File size: 13,038 Bytes
20e9692 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 |
"""AoTInductor compilation utilities for the VAD segmenter."""
import torch
from config import (
AOTI_ENABLED, AOTI_MIN_AUDIO_MINUTES, AOTI_MAX_AUDIO_MINUTES,
AOTI_HUB_REPO, AOTI_HUB_ENABLED,
)
from .segmenter_model import _segmenter_cache
# =============================================================================
# AoT Compilation Test
# =============================================================================
_aoti_cache = {
"exported": None,
"compiled": None,
"tested": False,
}
def is_aoti_applied() -> bool:
    """Report whether a compiled AoTI graph is currently applied to the model."""
    applied_flag = _aoti_cache.get("applied")
    return True if applied_flag else False
def _get_aoti_hub_filename():
    """Build the Hub artifact name, encoding the supported min/max durations."""
    lo, hi = AOTI_MIN_AUDIO_MINUTES, AOTI_MAX_AUDIO_MINUTES
    return f"vad_aoti_{lo}min_{hi}min.pt2"
def _try_load_aoti_from_hub(model):
    """
    Try to load a pre-compiled AoTI graph from the Hub and apply it to *model*.

    On success the model's ``forward`` is replaced by the compiled callable,
    its parameters are drained (freed), and ``_aoti_cache`` is marked applied.

    Args:
        model: The loaded VAD segmenter ``nn.Module`` (weights provide the
            state_dict fed to the compiled graph).

    Returns:
        bool: True if a compiled model was downloaded and applied, else False.
    """
    import os
    import time

    if not AOTI_HUB_ENABLED:
        print("[AoTI] Hub persistence disabled")
        return False
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("[AoTI] HF_TOKEN not set, cannot access Hub")
        return False

    filename = _get_aoti_hub_filename()
    # BUG FIX: the message previously interpolated a stale placeholder instead
    # of the actual artifact filename.
    print(f"[AoTI] Checking Hub for pre-compiled model: {AOTI_HUB_REPO}/{filename}")
    try:
        from huggingface_hub import hf_hub_download, HfApi

        # Check the file exists before attempting a (slow) download.
        api = HfApi(token=token)
        try:
            files = api.list_repo_files(AOTI_HUB_REPO, token=token)
            if filename not in files:
                print(f"[AoTI] Compiled model not found on Hub (available: {files})")
                return False
        except Exception as e:
            print(f"[AoTI] Could not list Hub repo: {e}")
            return False

        # Download the compiled graph archive.
        t0 = time.time()
        compiled_graph_file = hf_hub_download(AOTI_HUB_REPO, filename, token=token)
        download_time = time.time() - t0
        print(f"[AoTI] Downloaded from Hub in {download_time:.1f}s: {compiled_graph_file}")

        # Rehydrate via the ZeroGPU AOTI utilities: the compiled graph file is
        # paired with the current model's weights.
        from spaces.zero.torch.aoti import (
            ZeroGPUCompiledModel,
            ZeroGPUWeights,
            drain_module_parameters,
        )

        zerogpu_weights = ZeroGPUWeights(dict(model.state_dict()))
        compiled = ZeroGPUCompiledModel(compiled_graph_file, zerogpu_weights)

        # Swap the compiled callable in for forward and free the original
        # parameters (the compiled graph owns the weights now).
        model.forward = compiled
        drain_module_parameters(model)

        _aoti_cache["compiled"] = compiled
        _aoti_cache["applied"] = True
        print("[AoTI] Loaded and applied compiled model from Hub")
        return True
    except Exception as e:
        print(f"[AoTI] Failed to load from Hub: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return False
def _push_aoti_to_hub(compiled):
    """
    Upload a compiled AoTI archive to the Hub so future sessions can reuse it.

    Args:
        compiled: Compiled object exposing an ``archive_file`` BytesIO-like
            attribute (as produced by ``spaces.aoti_compile``).

    Returns:
        bool: True if the archive was uploaded, False otherwise.
    """
    import os
    import time
    import tempfile

    if not AOTI_HUB_ENABLED:
        print("[AoTI] Hub persistence disabled, skipping upload")
        return False
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("[AoTI] HF_TOKEN not set, cannot upload to Hub")
        return False

    filename = _get_aoti_hub_filename()
    # BUG FIX: the message previously interpolated a stale placeholder instead
    # of the actual artifact filename.
    print(f"[AoTI] Uploading compiled model to Hub: {AOTI_HUB_REPO}/{filename}")
    try:
        from huggingface_hub import HfApi, create_repo

        api = HfApi(token=token)
        # Create the repo if missing; failures here are non-fatal because the
        # repo may already exist (upload is attempted regardless).
        try:
            create_repo(AOTI_HUB_REPO, exist_ok=True, token=token)
        except Exception as e:
            print(f"[AoTI] Repo creation note: {e}")

        # Fetch the serialized archive from the compiled object.  getattr with
        # a default keeps the diagnostic message instead of an AttributeError.
        archive = getattr(compiled, "archive_file", None)
        if archive is None:
            print("[AoTI] Compiled object has no archive_file, cannot upload")
            return False

        t0 = time.time()
        # upload_file wants a path, so spill the in-memory archive to disk
        # inside a temp dir that is cleaned up automatically.
        with tempfile.TemporaryDirectory() as tmpdir:
            output_path = os.path.join(tmpdir, filename)
            with open(output_path, "wb") as f:
                f.write(archive.getvalue())
            info = api.upload_file(
                repo_id=AOTI_HUB_REPO,
                path_or_fileobj=output_path,
                path_in_repo=filename,
                commit_message=f"Add compiled VAD model ({AOTI_MIN_AUDIO_MINUTES}-{AOTI_MAX_AUDIO_MINUTES} min)",
                token=token,
            )
        upload_time = time.time() - t0
        print(f"[AoTI] Uploaded to Hub in {upload_time:.1f}s: {info}")
        return True
    except Exception as e:
        print(f"[AoTI] Failed to upload to Hub: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_vad_aoti_export():
    """
    Test torch.export AoT compilation for VAD model using spaces.aoti_capture.

    Must be called AFTER model is on GPU (inside GPU-decorated function).

    Checks Hub for pre-compiled model first. If found, loads it directly.
    Otherwise, compiles fresh and uploads to Hub for future reuse.

    Uses aoti_capture to capture the EXACT call signature from a real inference
    call to segment_recitations, ensuring the export matches what the model
    actually receives during inference.

    Returns dict with test results and timing.
    """
    import time
    # Result record returned in every path; "error" stays None on success.
    results = {
        "export_success": False,
        "export_time": 0.0,
        "compile_success": False,
        "compile_time": 0.0,
        "hub_loaded": False,
        "hub_uploaded": False,
        "error": None,
    }
    if not AOTI_ENABLED:
        results["error"] = "AoTI disabled in config"
        print("[AoTI] Disabled via AOTI_ENABLED=False")
        return results
    # Only run once per session; later calls get a "skipped" marker merged
    # with the (all-default) results.
    if _aoti_cache["tested"]:
        print("[AoTI] Already tested this session, skipping")
        return {"skipped": True, **results}
    # Mark as tested up-front so even a failing attempt is not retried.
    _aoti_cache["tested"] = True
    # Check model is loaded and on GPU
    if not _segmenter_cache["loaded"] or _segmenter_cache["model"] is None:
        results["error"] = "Model not loaded"
        print(f"[AoTI] {results['error']}")
        return results
    model = _segmenter_cache["model"]
    processor = _segmenter_cache["processor"]
    # Device/dtype are read off the first parameter; assumes all parameters
    # share the same placement.
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    if device.type != "cuda":
        results["error"] = f"Model not on GPU (device={device})"
        print(f"[AoTI] {results['error']}")
        return results
    print(f"[AoTI] Testing torch.export on VAD model (device={device}, dtype={dtype})")
    # Import spaces for aoti_capture
    try:
        import spaces
    except ImportError:
        results["error"] = "spaces module not available"
        print(f"[AoTI] {results['error']}")
        return results
    # Try to load pre-compiled model from Hub first
    if _try_load_aoti_from_hub(model):
        results["hub_loaded"] = True
        results["compile_success"] = True
        print("[AoTI] Using pre-compiled model from Hub")
        return results
    # No cached model found - compile fresh
    print("[AoTI] No cached model on Hub, compiling fresh...")
    # Convert config minutes to samples (16kHz audio)
    SAMPLES_PER_MINUTE = 16000 * 60
    min_samples = int(AOTI_MIN_AUDIO_MINUTES * SAMPLES_PER_MINUTE)
    # NOTE(review): max_samples is computed but never used below (the max
    # bound is derived in *frames* from the captured tensor instead).
    max_samples = int(AOTI_MAX_AUDIO_MINUTES * SAMPLES_PER_MINUTE)
    # Create test audio for capture - use min duration to save memory
    # MUST be on CPU - segment_recitations moves to GPU internally
    test_audio = torch.randn(min_samples, device="cpu")
    print(f"[AoTI] Test audio: {min_samples} samples ({AOTI_MIN_AUDIO_MINUTES} min)")
    # Capture the exact args/kwargs used by segment_recitations
    try:
        from recitations_segmenter import segment_recitations
        print("[AoTI] Capturing call signature via aoti_capture...")
        with spaces.aoti_capture(model) as call:
            segment_recitations(
                [test_audio], model, processor,
                device=device, dtype=dtype, batch_size=1,
            )
        print(f"[AoTI] Captured args: {len(call.args)} positional, {list(call.kwargs.keys())} kwargs")
    except Exception as e:
        results["error"] = f"aoti_capture failed: {type(e).__name__}: {e}"
        print(f"[AoTI] {results['error']}")
        import traceback
        traceback.print_exc()
        return results
    # Build dynamic shapes from captured tensors
    # The sequence dimension (T) varies with audio length
    try:
        from torch.export import export, Dim
        # Derive frame rate from captured tensor (model's actual output rate)
        # Find the first 2D+ tensor to get the captured frame count
        captured_frames = None
        # kwargs are scanned before positional args; first matching tensor wins.
        for val in list(call.kwargs.values()) + list(call.args):
            if isinstance(val, torch.Tensor) and val.dim() >= 2:
                captured_frames = val.shape[1]
                break
        if captured_frames is None:
            raise ValueError("No 2D+ tensor found in captured args/kwargs")
        # Calculate frames per minute from captured data
        frames_per_minute = captured_frames / AOTI_MIN_AUDIO_MINUTES
        min_frames = captured_frames  # Already at min duration
        max_frames = int(AOTI_MAX_AUDIO_MINUTES * frames_per_minute)
        dynamic_T = Dim("T", min=min_frames, max=max_frames)
        print(f"[AoTI] Captured {captured_frames} frames for {AOTI_MIN_AUDIO_MINUTES} min = {frames_per_minute:.1f} frames/min")
        print(f"[AoTI] Dynamic shape range: {min_frames}-{max_frames} frames")
        # Build dynamic_shapes dict matching the captured signature
        dynamic_shapes_args = []
        for arg in call.args:
            if isinstance(arg, torch.Tensor) and arg.dim() >= 2:
                # Assume sequence dim is dim 1 for 2D+ tensors
                dynamic_shapes_args.append({1: dynamic_T})
            else:
                dynamic_shapes_args.append(None)
        dynamic_shapes_kwargs = {}
        for key, val in call.kwargs.items():
            if isinstance(val, torch.Tensor) and val.dim() >= 2:
                dynamic_shapes_kwargs[key] = {1: dynamic_T}
            else:
                dynamic_shapes_kwargs[key] = None
        print(f"[AoTI] Dynamic shapes - args: {dynamic_shapes_args}, kwargs: {list(dynamic_shapes_kwargs.keys())}")
        t0 = time.time()
        # Export using captured signature - guarantees match with inference
        # dynamic_shapes is a (args, kwargs) pair when positional args exist,
        # otherwise just the kwargs mapping.
        exported = export(
            model,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=(dynamic_shapes_args, dynamic_shapes_kwargs) if dynamic_shapes_args else dynamic_shapes_kwargs,
            strict=False,
        )
        results["export_time"] = time.time() - t0
        results["export_success"] = True
        _aoti_cache["exported"] = exported
        print(f"[AoTI] torch.export SUCCESS in {results['export_time']:.1f}s")
    except Exception as e:
        results["error"] = f"torch.export failed: {type(e).__name__}: {e}"
        print(f"[AoTI] {results['error']}")
        import traceback
        traceback.print_exc()
        return results
    # Attempt spaces.aoti_compile
    try:
        t0 = time.time()
        compiled = spaces.aoti_compile(exported)
        results["compile_time"] = time.time() - t0
        results["compile_success"] = True
        _aoti_cache["compiled"] = compiled
        print(f"[AoTI] spaces.aoti_compile SUCCESS in {results['compile_time']:.1f}s")
        # Return compiled object - apply happens OUTSIDE GPU lease (in main process)
        results["compiled"] = compiled
        print(f"[AoTI] Compiled object ready for apply")
        # Upload to Hub for future reuse
        if _push_aoti_to_hub(compiled):
            results["hub_uploaded"] = True
    except Exception as e:
        results["error"] = f"aoti_compile failed: {type(e).__name__}: {e}"
        print(f"[AoTI] {results['error']}")
        import traceback
        traceback.print_exc()
    # NOTE(review): placed at function level so the success path also returns
    # results; the flattened source is ambiguous here — confirm against the
    # original indentation.
    return results
def apply_aoti_compiled(compiled):
    """
    Attach a previously compiled AoTI graph to the cached VAD segmenter model.

    Must be called OUTSIDE the GPU lease, in the main process.

    Returns True when the compiled graph was applied, False otherwise.
    """
    if compiled is None:
        print("[AoTI] No compiled object to apply")
        return False

    model = _segmenter_cache.get("model")
    if model is None:
        print("[AoTI] Model not loaded, cannot apply")
        return False

    try:
        import spaces
        spaces.aoti_apply(compiled, model)
        _aoti_cache.update(compiled=compiled, applied=True)
        print(f"[AoTI] Compiled model applied to VAD (model_id={id(model)})")
        return True
    except Exception as err:
        print(f"[AoTI] Apply failed: {err}")
        return False
|