File size: 13,038 Bytes
20e9692
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
"""AoTInductor compilation utilities for the VAD segmenter."""

import torch

from config import (
    AOTI_ENABLED, AOTI_MIN_AUDIO_MINUTES, AOTI_MAX_AUDIO_MINUTES,
    AOTI_HUB_REPO, AOTI_HUB_ENABLED,
)
from .segmenter_model import _segmenter_cache


# =============================================================================
# AoT Compilation Test
# =============================================================================

_aoti_cache = {
    "exported": None,
    "compiled": None,
    "tested": False,
}


def is_aoti_applied() -> bool:
    """Report whether an AoTI-compiled forward has been installed on the model."""
    return True if _aoti_cache.get("applied") else False


def _get_aoti_hub_filename():
    """Build the Hub artifact name, encoding the min/max audio-duration bounds."""
    return "vad_aoti_{}min_{}min.pt2".format(
        AOTI_MIN_AUDIO_MINUTES, AOTI_MAX_AUDIO_MINUTES
    )


def _try_load_aoti_from_hub(model):
    """
    Try to load a pre-compiled AoTI model from the Hub and apply it to *model*.

    On success, the model's ``forward`` is replaced by the compiled callable,
    its eager parameters are drained (freed), and ``_aoti_cache`` is updated.

    Args:
        model: The loaded VAD segmenter module whose weights pair with the
            downloaded compiled graph.

    Returns:
        bool: True if a compiled model was downloaded and applied, else False.
    """
    import os
    import time

    if not AOTI_HUB_ENABLED:
        print("[AoTI] Hub persistence disabled")
        return False

    token = os.environ.get("HF_TOKEN")
    if not token:
        print("[AoTI] HF_TOKEN not set, cannot access Hub")
        return False

    filename = _get_aoti_hub_filename()
    # BUGFIX: log the actual artifact name; the message previously printed
    # the literal "(unknown)" instead of the computed filename.
    print(f"[AoTI] Checking Hub for pre-compiled model: {AOTI_HUB_REPO}/{filename}")

    try:
        from huggingface_hub import hf_hub_download, HfApi

        # Cheap existence check first, so a missing artifact is reported
        # cleanly instead of raising inside hf_hub_download.
        api = HfApi(token=token)
        try:
            files = api.list_repo_files(AOTI_HUB_REPO, token=token)
            if filename not in files:
                print(f"[AoTI] Compiled model not found on Hub (available: {files})")
                return False
        except Exception as e:
            print(f"[AoTI] Could not list Hub repo: {e}")
            return False

        # Download the serialized compiled graph (.pt2 archive).
        t0 = time.time()
        compiled_graph_file = hf_hub_download(
            AOTI_HUB_REPO, filename, token=token
        )
        download_time = time.time() - t0
        print(f"[AoTI] Downloaded from Hub in {download_time:.1f}s: {compiled_graph_file}")

        # Rehydrate via ZeroGPU AOTI utilities: pair the downloaded graph
        # with the model's current weights.
        from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights, drain_module_parameters

        state_dict = model.state_dict()
        zerogpu_weights = ZeroGPUWeights(dict(state_dict))
        compiled = ZeroGPUCompiledModel(compiled_graph_file, zerogpu_weights)

        # Route all forward calls through the compiled graph, then release
        # the now-redundant eager parameters.
        setattr(model, "forward", compiled)
        drain_module_parameters(model)

        _aoti_cache["compiled"] = compiled
        _aoti_cache["applied"] = True
        print("[AoTI] Loaded and applied compiled model from Hub")
        return True

    except Exception as e:
        # Best-effort: any failure falls back to fresh compilation upstream.
        print(f"[AoTI] Failed to load from Hub: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return False


def _push_aoti_to_hub(compiled):
    """
    Push a compiled AoTI model archive to the Hub for future reuse.

    Args:
        compiled: Compiled object exposing an ``archive_file`` BytesIO
            attribute (as produced by spaces.aoti_compile).

    Returns:
        bool: True on successful upload, False otherwise (missing token,
        disabled persistence, missing archive, or any Hub error).
    """
    import os
    import time
    import tempfile

    if not AOTI_HUB_ENABLED:
        print("[AoTI] Hub persistence disabled, skipping upload")
        return False

    token = os.environ.get("HF_TOKEN")
    if not token:
        print("[AoTI] HF_TOKEN not set, cannot upload to Hub")
        return False

    filename = _get_aoti_hub_filename()
    # BUGFIX: log the actual artifact name; the message previously printed
    # the literal "(unknown)" instead of the computed filename.
    print(f"[AoTI] Uploading compiled model to Hub: {AOTI_HUB_REPO}/{filename}")

    try:
        from huggingface_hub import HfApi, create_repo

        api = HfApi(token=token)

        # Create repo if it doesn't exist; exist_ok makes this idempotent.
        try:
            create_repo(AOTI_HUB_REPO, exist_ok=True, token=token)
        except Exception as e:
            print(f"[AoTI] Repo creation note: {e}")

        # Get the archive file from the compiled object.
        archive = compiled.archive_file
        if archive is None:
            print("[AoTI] Compiled object has no archive_file, cannot upload")
            return False

        t0 = time.time()

        # Write the in-memory archive to a temp file and upload it; the
        # temp dir is cleaned up automatically.
        with tempfile.TemporaryDirectory() as tmpdir:
            output_path = os.path.join(tmpdir, filename)

            # archive is a BytesIO object
            with open(output_path, "wb") as f:
                f.write(archive.getvalue())

            info = api.upload_file(
                repo_id=AOTI_HUB_REPO,
                path_or_fileobj=output_path,
                path_in_repo=filename,
                commit_message=f"Add compiled VAD model ({AOTI_MIN_AUDIO_MINUTES}-{AOTI_MAX_AUDIO_MINUTES} min)",
                token=token,
            )

        upload_time = time.time() - t0
        print(f"[AoTI] Uploaded to Hub in {upload_time:.1f}s: {info}")
        return True

    except Exception as e:
        # Upload is an optimization; failures are logged but non-fatal.
        print(f"[AoTI] Failed to upload to Hub: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_vad_aoti_export() -> dict:
    """
    Test torch.export AoT compilation for VAD model using spaces.aoti_capture.
    Must be called AFTER model is on GPU (inside GPU-decorated function).

    Checks Hub for pre-compiled model first. If found, loads it directly.
    Otherwise, compiles fresh and uploads to Hub for future reuse.

    Uses aoti_capture to capture the EXACT call signature from a real inference
    call to segment_recitations, ensuring the export matches what the model
    actually receives during inference.

    Returns dict with test results and timing; on success the compiled object
    is also placed under the "compiled" key so the caller can apply it outside
    the GPU lease via apply_aoti_compiled().
    """
    import time

    # Result skeleton returned on every code path; "error" stays None on success.
    results = {
        "export_success": False,
        "export_time": 0.0,
        "compile_success": False,
        "compile_time": 0.0,
        "hub_loaded": False,
        "hub_uploaded": False,
        "error": None,
    }

    if not AOTI_ENABLED:
        results["error"] = "AoTI disabled in config"
        print("[AoTI] Disabled via AOTI_ENABLED=False")
        return results

    # Run at most once per session; repeat calls short-circuit.
    if _aoti_cache["tested"]:
        print("[AoTI] Already tested this session, skipping")
        return {"skipped": True, **results}

    # Marked before any work so a failure below also counts as "tested".
    _aoti_cache["tested"] = True

    # Check model is loaded and on GPU
    if not _segmenter_cache["loaded"] or _segmenter_cache["model"] is None:
        results["error"] = "Model not loaded"
        print(f"[AoTI] {results['error']}")
        return results

    model = _segmenter_cache["model"]
    processor = _segmenter_cache["processor"]
    # Device/dtype taken from the first parameter; assumes a homogeneous model.
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    if device.type != "cuda":
        results["error"] = f"Model not on GPU (device={device})"
        print(f"[AoTI] {results['error']}")
        return results

    print(f"[AoTI] Testing torch.export on VAD model (device={device}, dtype={dtype})")

    # Import spaces for aoti_capture
    try:
        import spaces
    except ImportError:
        results["error"] = "spaces module not available"
        print(f"[AoTI] {results['error']}")
        return results

    # Try to load pre-compiled model from Hub first
    if _try_load_aoti_from_hub(model):
        results["hub_loaded"] = True
        results["compile_success"] = True
        print("[AoTI] Using pre-compiled model from Hub")
        return results

    # No cached model found - compile fresh
    print("[AoTI] No cached model on Hub, compiling fresh...")

    # Convert config minutes to samples (16kHz audio)
    SAMPLES_PER_MINUTE = 16000 * 60
    min_samples = int(AOTI_MIN_AUDIO_MINUTES * SAMPLES_PER_MINUTE)
    max_samples = int(AOTI_MAX_AUDIO_MINUTES * SAMPLES_PER_MINUTE)

    # Create test audio for capture - use min duration to save memory
    # MUST be on CPU - segment_recitations moves to GPU internally
    test_audio = torch.randn(min_samples, device="cpu")
    print(f"[AoTI] Test audio: {min_samples} samples ({AOTI_MIN_AUDIO_MINUTES} min)")

    # Capture the exact args/kwargs used by segment_recitations
    try:
        from recitations_segmenter import segment_recitations

        print("[AoTI] Capturing call signature via aoti_capture...")
        # aoti_capture intercepts the model's forward call; `call` exposes
        # the intercepted .args/.kwargs after the with-block.
        with spaces.aoti_capture(model) as call:
            segment_recitations(
                [test_audio], model, processor,
                device=device, dtype=dtype, batch_size=1,
            )

        print(f"[AoTI] Captured args: {len(call.args)} positional, {list(call.kwargs.keys())} kwargs")

    except Exception as e:
        results["error"] = f"aoti_capture failed: {type(e).__name__}: {e}"
        print(f"[AoTI] {results['error']}")
        import traceback
        traceback.print_exc()
        return results

    # Build dynamic shapes from captured tensors
    # The sequence dimension (T) varies with audio length
    try:
        from torch.export import export, Dim

        # Derive frame rate from captured tensor (model's actual output rate)
        # Find the first 2D+ tensor to get the captured frame count
        # NOTE(review): assumes dim 1 of the first 2D+ tensor is the frame/
        # sequence axis — true for (batch, seq, ...) layouts; confirm if the
        # processor output layout ever changes.
        captured_frames = None
        for val in list(call.kwargs.values()) + list(call.args):
            if isinstance(val, torch.Tensor) and val.dim() >= 2:
                captured_frames = val.shape[1]
                break

        if captured_frames is None:
            raise ValueError("No 2D+ tensor found in captured args/kwargs")

        # Calculate frames per minute from captured data
        frames_per_minute = captured_frames / AOTI_MIN_AUDIO_MINUTES
        min_frames = captured_frames  # Already at min duration
        max_frames = int(AOTI_MAX_AUDIO_MINUTES * frames_per_minute)
        dynamic_T = Dim("T", min=min_frames, max=max_frames)
        print(f"[AoTI] Captured {captured_frames} frames for {AOTI_MIN_AUDIO_MINUTES} min = {frames_per_minute:.1f} frames/min")
        print(f"[AoTI] Dynamic shape range: {min_frames}-{max_frames} frames")

        # Build dynamic_shapes dict matching the captured signature:
        # every 2D+ tensor gets the shared symbolic T on dim 1, everything
        # else is marked static (None).
        dynamic_shapes_args = []
        for arg in call.args:
            if isinstance(arg, torch.Tensor) and arg.dim() >= 2:
                # Assume sequence dim is dim 1 for 2D+ tensors
                dynamic_shapes_args.append({1: dynamic_T})
            else:
                dynamic_shapes_args.append(None)

        dynamic_shapes_kwargs = {}
        for key, val in call.kwargs.items():
            if isinstance(val, torch.Tensor) and val.dim() >= 2:
                dynamic_shapes_kwargs[key] = {1: dynamic_T}
            else:
                dynamic_shapes_kwargs[key] = None

        print(f"[AoTI] Dynamic shapes - args: {dynamic_shapes_args}, kwargs: {list(dynamic_shapes_kwargs.keys())}")

        t0 = time.time()
        # Export using captured signature - guarantees match with inference.
        # strict=False uses non-strict tracing, which tolerates data-dependent
        # control flow that strict export would reject.
        exported = export(
            model,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=(dynamic_shapes_args, dynamic_shapes_kwargs) if dynamic_shapes_args else dynamic_shapes_kwargs,
            strict=False,
        )
        results["export_time"] = time.time() - t0
        results["export_success"] = True
        _aoti_cache["exported"] = exported
        print(f"[AoTI] torch.export SUCCESS in {results['export_time']:.1f}s")

    except Exception as e:
        results["error"] = f"torch.export failed: {type(e).__name__}: {e}"
        print(f"[AoTI] {results['error']}")
        import traceback
        traceback.print_exc()
        return results

    # Attempt spaces.aoti_compile
    try:
        t0 = time.time()
        compiled = spaces.aoti_compile(exported)
        results["compile_time"] = time.time() - t0
        results["compile_success"] = True
        _aoti_cache["compiled"] = compiled
        print(f"[AoTI] spaces.aoti_compile SUCCESS in {results['compile_time']:.1f}s")

        # Return compiled object - apply happens OUTSIDE GPU lease (in main process)
        results["compiled"] = compiled
        print(f"[AoTI] Compiled object ready for apply")

        # Upload to Hub for future reuse
        if _push_aoti_to_hub(compiled):
            results["hub_uploaded"] = True

    except Exception as e:
        results["error"] = f"aoti_compile failed: {type(e).__name__}: {e}"
        print(f"[AoTI] {results['error']}")
        import traceback
        traceback.print_exc()

    return results


def apply_aoti_compiled(compiled):
    """
    Install a compiled AoTI graph on the cached VAD segmenter model.

    Must be called OUTSIDE GPU lease, in main process.

    Returns True on success, False if there is nothing to apply, the model
    is not loaded, or the apply step raises.
    """
    # Guard clauses: nothing to do without both a compiled object and a model.
    if compiled is None:
        print("[AoTI] No compiled object to apply")
        return False

    model = _segmenter_cache.get("model")
    if model is None:
        print("[AoTI] Model not loaded, cannot apply")
        return False

    try:
        import spaces

        spaces.aoti_apply(compiled, model)
        _aoti_cache.update(compiled=compiled, applied=True)
        print(f"[AoTI] Compiled model applied to VAD (model_id={id(model)})")
        return True
    except Exception as e:
        print(f"[AoTI] Apply failed: {e}")
        return False