Spaces: Running on Zero
| """AoTInductor compilation utilities for the VAD segmenter.""" | |
| import torch | |
| from config import ( | |
| AOTI_ENABLED, AOTI_MIN_AUDIO_MINUTES, AOTI_MAX_AUDIO_MINUTES, | |
| AOTI_HUB_REPO, AOTI_HUB_ENABLED, | |
| ) | |
| from .segmenter_model import _segmenter_cache | |
| # ============================================================================= | |
| # AoT Compilation Test | |
| # ============================================================================= | |
# Per-process AoTI state. "applied" was previously only created dynamically by
# _try_load_aoti_from_hub / apply_aoti_compiled and read via .get("applied");
# declaring it here makes the cache schema explicit and self-documenting.
_aoti_cache = {
    "exported": None,   # torch.export ExportedProgram, once export succeeds
    "compiled": None,   # spaces.aoti_compile result (or Hub-loaded compiled model)
    "tested": False,    # True after the first compile attempt this session
    "applied": False,   # True once the compiled graph replaces model.forward
}
def is_aoti_applied() -> bool:
    """Report whether a compiled AoTI graph has been swapped into the model."""
    applied_flag = _aoti_cache.get("applied")
    return True if applied_flag else False
def _get_aoti_hub_filename():
    """Build the Hub artifact name, encoding the configured min/max audio duration."""
    duration_range = (AOTI_MIN_AUDIO_MINUTES, AOTI_MAX_AUDIO_MINUTES)
    return "vad_aoti_{}min_{}min.pt2".format(*duration_range)
def _try_load_aoti_from_hub(model):
    """
    Try to load a pre-compiled AoTI model from the Hub and apply it to *model*.

    Downloads the ``.pt2`` archive whose name encodes the configured min/max
    audio duration, binds the model's current weights to the downloaded graph
    via the ZeroGPU AOTI utilities, and replaces ``model.forward`` with the
    compiled graph.

    Args:
        model: The loaded VAD segmenter module whose forward is to be replaced.

    Returns:
        bool: True if a compiled model was found, loaded, and applied;
        False on any failure (missing token, file absent, load error).
    """
    import os
    import time

    if not AOTI_HUB_ENABLED:
        print("[AoTI] Hub persistence disabled")
        return False
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("[AoTI] HF_TOKEN not set, cannot access Hub")
        return False

    filename = _get_aoti_hub_filename()
    # BUG FIX: this message previously printed a literal placeholder instead of
    # the filename actually being looked up, making the log useless for debugging.
    print(f"[AoTI] Checking Hub for pre-compiled model: {AOTI_HUB_REPO}/{filename}")
    try:
        from huggingface_hub import hf_hub_download, HfApi

        # Check the file exists before attempting a download, so a missing
        # artifact is a clean "compile fresh" signal rather than an exception.
        api = HfApi(token=token)
        try:
            files = api.list_repo_files(AOTI_HUB_REPO, token=token)
            if filename not in files:
                print(f"[AoTI] Compiled model not found on Hub (available: {files})")
                return False
        except Exception as e:
            print(f"[AoTI] Could not list Hub repo: {e}")
            return False

        # Download the compiled graph archive.
        t0 = time.time()
        compiled_graph_file = hf_hub_download(
            AOTI_HUB_REPO, filename, token=token
        )
        download_time = time.time() - t0
        print(f"[AoTI] Downloaded from Hub in {download_time:.1f}s: {compiled_graph_file}")

        # Load using ZeroGPU AOTI utilities: wrap the live state_dict as
        # ZeroGPUWeights and pair it with the downloaded compiled graph.
        from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights, drain_module_parameters

        state_dict = model.state_dict()
        zerogpu_weights = ZeroGPUWeights({name: weight for name, weight in state_dict.items()})
        compiled = ZeroGPUCompiledModel(compiled_graph_file, zerogpu_weights)

        # Replace forward, then drain the (now unused) eager parameters
        # so they don't hold memory alongside the compiled weights.
        setattr(model, "forward", compiled)
        drain_module_parameters(model)

        _aoti_cache["compiled"] = compiled
        _aoti_cache["applied"] = True
        print(f"[AoTI] Loaded and applied compiled model from Hub")
        return True
    except Exception as e:
        # Deliberate best-effort: any load failure falls back to fresh compile.
        print(f"[AoTI] Failed to load from Hub: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return False
def _push_aoti_to_hub(compiled):
    """
    Push a compiled AoTI model to the Hub for reuse by future sessions.

    Creates the target repo if needed, writes the compiled archive
    (a BytesIO held on ``compiled.archive_file``) to a temp file, and
    uploads it under the duration-encoded filename.

    Args:
        compiled: The object returned by ``spaces.aoti_compile``; must expose
            an ``archive_file`` BytesIO attribute.

    Returns:
        bool: True on successful upload, False otherwise.
    """
    import os
    import time
    import tempfile

    if not AOTI_HUB_ENABLED:
        print("[AoTI] Hub persistence disabled, skipping upload")
        return False
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("[AoTI] HF_TOKEN not set, cannot upload to Hub")
        return False

    filename = _get_aoti_hub_filename()
    # BUG FIX: this message previously printed a literal placeholder instead of
    # the filename actually being uploaded.
    print(f"[AoTI] Uploading compiled model to Hub: {AOTI_HUB_REPO}/{filename}")
    try:
        from huggingface_hub import HfApi, create_repo

        api = HfApi(token=token)
        # Create repo if it doesn't exist; exist_ok makes re-runs harmless.
        try:
            create_repo(AOTI_HUB_REPO, exist_ok=True, token=token)
        except Exception as e:
            print(f"[AoTI] Repo creation note: {e}")

        # Get the archive file from the compiled object.
        archive = compiled.archive_file
        if archive is None:
            print("[AoTI] Compiled object has no archive_file, cannot upload")
            return False

        t0 = time.time()
        # Write archive to a temp file and upload; the temp dir is cleaned up
        # automatically once the upload completes (or fails).
        with tempfile.TemporaryDirectory() as tmpdir:
            output_path = os.path.join(tmpdir, filename)
            # archive is a BytesIO object
            with open(output_path, "wb") as f:
                f.write(archive.getvalue())
            info = api.upload_file(
                repo_id=AOTI_HUB_REPO,
                path_or_fileobj=output_path,
                path_in_repo=filename,
                commit_message=f"Add compiled VAD model ({AOTI_MIN_AUDIO_MINUTES}-{AOTI_MAX_AUDIO_MINUTES} min)",
                token=token,
            )
        upload_time = time.time() - t0
        print(f"[AoTI] Uploaded to Hub in {upload_time:.1f}s: {info}")
        return True
    except Exception as e:
        # Best-effort: a failed upload only costs a recompile next session.
        print(f"[AoTI] Failed to upload to Hub: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_vad_aoti_export():
    """
    Test torch.export AoT compilation for VAD model using spaces.aoti_capture.

    Must be called AFTER model is on GPU (inside GPU-decorated function).

    Checks Hub for pre-compiled model first. If found, loads it directly.
    Otherwise, compiles fresh and uploads to Hub for future reuse.

    Uses aoti_capture to capture the EXACT call signature from a real inference
    call to segment_recitations, ensuring the export matches what the model
    actually receives during inference.

    Returns dict with test results and timing.
    """
    import time

    # Result skeleton returned on every path; "error" stays None on success.
    results = {
        "export_success": False,
        "export_time": 0.0,
        "compile_success": False,
        "compile_time": 0.0,
        "hub_loaded": False,
        "hub_uploaded": False,
        "error": None,
    }
    if not AOTI_ENABLED:
        results["error"] = "AoTI disabled in config"
        print("[AoTI] Disabled via AOTI_ENABLED=False")
        return results
    # One attempt per process: mark "tested" up front so a crash mid-compile
    # is not retried on the next call.
    if _aoti_cache["tested"]:
        print("[AoTI] Already tested this session, skipping")
        return {"skipped": True, **results}
    _aoti_cache["tested"] = True

    # Check model is loaded and on GPU
    if not _segmenter_cache["loaded"] or _segmenter_cache["model"] is None:
        results["error"] = "Model not loaded"
        print(f"[AoTI] {results['error']}")
        return results
    model = _segmenter_cache["model"]
    processor = _segmenter_cache["processor"]
    # Device/dtype are taken from the first parameter; assumes all parameters
    # share one device and dtype — typical for this model, but unverified here.
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    if device.type != "cuda":
        results["error"] = f"Model not on GPU (device={device})"
        print(f"[AoTI] {results['error']}")
        return results
    print(f"[AoTI] Testing torch.export on VAD model (device={device}, dtype={dtype})")

    # Import spaces for aoti_capture
    try:
        import spaces
    except ImportError:
        results["error"] = "spaces module not available"
        print(f"[AoTI] {results['error']}")
        return results

    # Try to load pre-compiled model from Hub first (fast path — skips
    # export/compile entirely when a matching artifact exists).
    if _try_load_aoti_from_hub(model):
        results["hub_loaded"] = True
        results["compile_success"] = True
        print("[AoTI] Using pre-compiled model from Hub")
        return results

    # No cached model found - compile fresh
    print("[AoTI] No cached model on Hub, compiling fresh...")

    # Convert config minutes to samples (16kHz audio)
    SAMPLES_PER_MINUTE = 16000 * 60
    min_samples = int(AOTI_MIN_AUDIO_MINUTES * SAMPLES_PER_MINUTE)
    # NOTE(review): max_samples is computed but never used below — the max
    # bound is derived later in frame space (max_frames) instead.
    max_samples = int(AOTI_MAX_AUDIO_MINUTES * SAMPLES_PER_MINUTE)

    # Create test audio for capture - use min duration to save memory
    # MUST be on CPU - segment_recitations moves to GPU internally
    test_audio = torch.randn(min_samples, device="cpu")
    print(f"[AoTI] Test audio: {min_samples} samples ({AOTI_MIN_AUDIO_MINUTES} min)")

    # Capture the exact args/kwargs used by segment_recitations
    try:
        from recitations_segmenter import segment_recitations
        print("[AoTI] Capturing call signature via aoti_capture...")
        # aoti_capture records the args/kwargs of the model's forward call
        # made inside this context, without needing to know its signature.
        with spaces.aoti_capture(model) as call:
            segment_recitations(
                [test_audio], model, processor,
                device=device, dtype=dtype, batch_size=1,
            )
        print(f"[AoTI] Captured args: {len(call.args)} positional, {list(call.kwargs.keys())} kwargs")
    except Exception as e:
        results["error"] = f"aoti_capture failed: {type(e).__name__}: {e}"
        print(f"[AoTI] {results['error']}")
        import traceback
        traceback.print_exc()
        return results

    # Build dynamic shapes from captured tensors
    # The sequence dimension (T) varies with audio length
    try:
        from torch.export import export, Dim

        # Derive frame rate from captured tensor (model's actual output rate)
        # Find the first 2D+ tensor to get the captured frame count
        captured_frames = None
        for val in list(call.kwargs.values()) + list(call.args):
            if isinstance(val, torch.Tensor) and val.dim() >= 2:
                captured_frames = val.shape[1]
                break
        if captured_frames is None:
            raise ValueError("No 2D+ tensor found in captured args/kwargs")

        # Calculate frames per minute from captured data; since the test audio
        # is exactly AOTI_MIN_AUDIO_MINUTES long, captured_frames IS min_frames.
        frames_per_minute = captured_frames / AOTI_MIN_AUDIO_MINUTES
        min_frames = captured_frames  # Already at min duration
        max_frames = int(AOTI_MAX_AUDIO_MINUTES * frames_per_minute)
        dynamic_T = Dim("T", min=min_frames, max=max_frames)
        print(f"[AoTI] Captured {captured_frames} frames for {AOTI_MIN_AUDIO_MINUTES} min = {frames_per_minute:.1f} frames/min")
        print(f"[AoTI] Dynamic shape range: {min_frames}-{max_frames} frames")

        # Build dynamic_shapes dict matching the captured signature:
        # every 2D+ tensor gets dynamic dim 1; everything else is static (None).
        dynamic_shapes_args = []
        for arg in call.args:
            if isinstance(arg, torch.Tensor) and arg.dim() >= 2:
                # Assume sequence dim is dim 1 for 2D+ tensors
                dynamic_shapes_args.append({1: dynamic_T})
            else:
                dynamic_shapes_args.append(None)
        dynamic_shapes_kwargs = {}
        for key, val in call.kwargs.items():
            if isinstance(val, torch.Tensor) and val.dim() >= 2:
                dynamic_shapes_kwargs[key] = {1: dynamic_T}
            else:
                dynamic_shapes_kwargs[key] = None
        print(f"[AoTI] Dynamic shapes - args: {dynamic_shapes_args}, kwargs: {list(dynamic_shapes_kwargs.keys())}")

        t0 = time.time()
        # Export using captured signature - guarantees match with inference.
        # strict=False uses non-strict export (no TorchDynamo graph capture).
        exported = export(
            model,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=(dynamic_shapes_args, dynamic_shapes_kwargs) if dynamic_shapes_args else dynamic_shapes_kwargs,
            strict=False,
        )
        results["export_time"] = time.time() - t0
        results["export_success"] = True
        _aoti_cache["exported"] = exported
        print(f"[AoTI] torch.export SUCCESS in {results['export_time']:.1f}s")
    except Exception as e:
        results["error"] = f"torch.export failed: {type(e).__name__}: {e}"
        print(f"[AoTI] {results['error']}")
        import traceback
        traceback.print_exc()
        return results

    # Attempt spaces.aoti_compile
    try:
        t0 = time.time()
        compiled = spaces.aoti_compile(exported)
        results["compile_time"] = time.time() - t0
        results["compile_success"] = True
        _aoti_cache["compiled"] = compiled
        print(f"[AoTI] spaces.aoti_compile SUCCESS in {results['compile_time']:.1f}s")
        # Return compiled object - apply happens OUTSIDE GPU lease (in main process)
        results["compiled"] = compiled
        print(f"[AoTI] Compiled object ready for apply")
        # Upload to Hub for future reuse (best-effort; failure is non-fatal).
        if _push_aoti_to_hub(compiled):
            results["hub_uploaded"] = True
    except Exception as e:
        results["error"] = f"aoti_compile failed: {type(e).__name__}: {e}"
        print(f"[AoTI] {results['error']}")
        import traceback
        traceback.print_exc()
    return results
def apply_aoti_compiled(compiled):
    """
    Apply a previously compiled AoTI model to the VAD segmenter.

    Must be called OUTSIDE the GPU lease, in the main process.

    Returns True on success, False when there is nothing to apply,
    the model is not loaded, or the apply itself fails.
    """
    # Guard clauses: nothing to do without both a compiled object and a model.
    if compiled is None:
        print("[AoTI] No compiled object to apply")
        return False
    model = _segmenter_cache.get("model")
    if model is None:
        print("[AoTI] Model not loaded, cannot apply")
        return False

    try:
        import spaces
        spaces.aoti_apply(compiled, model)
    except Exception as e:
        print(f"[AoTI] Apply failed: {e}")
        return False

    # Record success in the session cache so is_aoti_applied() reports it.
    _aoti_cache["compiled"] = compiled
    _aoti_cache["applied"] = True
    print(f"[AoTI] Compiled model applied to VAD (model_id={id(model)})")
    return True