"""
BlockSwap Module for SeedVR2

This module implements dynamic block swapping between GPU and CPU memory
to enable running large models on limited VRAM systems.

Key Features:
- Dynamic transformer block offloading during inference
- Non-blocking GPU transfers for optimal performance
- RoPE computation fallback to CPU on OOM
- Minimal performance overhead with intelligent caching
- I/O component offloading for maximum memory savings
"""

import time
import types
import torch
import weakref

from typing import Dict, Any, List, Optional
from .memory_manager import clear_memory
from .compatibility import call_rope_with_stability
from ..common.distributed import get_device
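
# Typical usage (illustrative sketch; the actual wiring lives in configure_runner(),
# only the config keys below are defined by this module):
#
#     block_swap_config = {
#         "blocks_to_swap": 16,             # offload the first 16 transformer blocks
#         "swap_io_components": True,       # also offload embeddings / norm layers
#         "offload_device": torch.device("cpu"),
#     }
#     block_swap_config = validate_blockswap_config(
#         block_swap_config, dit_device, dit_offload_device, debug
#     )
#     apply_block_swap_to_dit(runner, block_swap_config, debug)
#     ...run inference...
#     cleanup_blockswap(runner, keep_state_for_cache=False)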


def is_blockswap_enabled(config: Optional[Dict[str, Any]]) -> bool:
    """
    Check if BlockSwap configuration indicates BlockSwap should be enabled.
    
    BlockSwap is enabled if either blocks_to_swap > 0 OR swap_io_components is True.
    This is the authoritative function for determining BlockSwap status from configuration.
    
    Args:
        config: BlockSwap configuration dictionary with optional keys:
            - blocks_to_swap: Number of blocks to offload (0 = disabled)
            - swap_io_components: Whether to offload I/O components
        
    Returns:
        True if BlockSwap should be active, False otherwise
    """
    if not config:
        return False
    
    blocks_to_swap = config.get("blocks_to_swap", 0)
    swap_io_components = config.get("swap_io_components", False)
    
    return blocks_to_swap > 0 or swap_io_components


def validate_blockswap_config(
    block_swap_config: Optional[Dict[str, Any]],
    dit_device: 'torch.device',
    dit_offload_device: Optional['torch.device'],
    debug: 'Debug'
) -> Optional[Dict[str, Any]]:
    """
    Validate and potentially modify BlockSwap configuration.
    
    Performs platform-specific validation and configuration adjustment:
    - On macOS (MPS): Auto-disables BlockSwap since unified memory makes it meaningless
    - On other platforms: Validates that offload_device is properly configured
    
    This is the single authoritative validation point for BlockSwap configuration,
    called early in configure_runner() before any model loading.
    
    Args:
        block_swap_config: BlockSwap configuration dictionary (may be None)
        dit_device: Target device for DiT model inference
        dit_offload_device: Device for offloading DiT blocks (may be None)
        debug: Debug instance for logging warnings/errors
        
    Returns:
        Validated/modified block_swap_config (may be None or modified copy)
        
    Raises:
        ValueError: If BlockSwap is enabled but offload_device is invalid (non-MPS only)
    """
    if not is_blockswap_enabled(block_swap_config):
        return block_swap_config
    
    blocks_to_swap = block_swap_config.get("blocks_to_swap", 0)
    swap_io_components = block_swap_config.get("swap_io_components", False)
    
    # Check for macOS unified memory - BlockSwap is meaningless there
    if dit_device.type == "mps":
        debug.log(
            f"BlockSwap disabled: macOS uses unified memory (no separate VRAM/RAM). "
            f"Ignoring blocks_to_swap={blocks_to_swap}, swap_io_components={swap_io_components}",
            level="WARNING", category="blockswap", force=True
        )
        # Return disabled config
        return {
            **block_swap_config,
            "blocks_to_swap": 0,
            "swap_io_components": False
        }
    
    # Validate offload_device is set and different from dit_device
    offload_device_valid = (
        dit_offload_device is not None and 
        str(dit_offload_device) != str(dit_device)
    )
    
    if not offload_device_valid:
        config_details = []
        if blocks_to_swap > 0:
            config_details.append(f"blocks_to_swap={blocks_to_swap}")
        if swap_io_components:
            config_details.append("swap_io_components=True")
        
        offload_str = str(dit_offload_device) if dit_offload_device else "none"
        raise ValueError(
            f"BlockSwap enabled ({', '.join(config_details)}) but dit_offload_device is invalid. "
            f"Current: device='{dit_device}', dit_offload_device='{offload_str}'. "
            f"BlockSwap requires offload_device on the DiT Model to be set and different from device. "
            f"Set --dit_offload_device cpu or disable BlockSwap."
        )
    
    return block_swap_config


# Timing helpers marked to skip torch.compile tracing
# These functions are excluded from Dynamo's graph tracing to avoid warnings
# about non-traceable builtins like time.time(), but they still execute normally
@torch._dynamo.disable
def _get_swap_start_time(debug, enabled: bool) -> Optional[float]:
    """Get start time for swap operation if debug is enabled."""
    return time.time() if debug and enabled else None


@torch._dynamo.disable  
def _log_swap_timing(debug, t_start: Optional[float], component_id, component_type: str) -> None:
    """Log swap timing if start time was captured."""
    if debug and t_start is not None:
        debug.log_swap_time(
            component_id=component_id,
            duration=time.time() - t_start,
            component_type=component_type
        )


def get_module_memory_mb(module: torch.nn.Module) -> float:
    """
    Calculate memory usage of a module in MB.
    
    Args:
        module: PyTorch module to measure
        
    Returns:
        Memory usage in megabytes
    """
    total_bytes = sum(
        param.nelement() * param.element_size() 
        for param in module.parameters() 
        if param.data is not None
    )
    return total_bytes / (1024 * 1024)


def apply_block_swap_to_dit(
    runner: 'VideoDiffusionInfer',
    block_swap_config: Dict[str, Any],
    debug: 'Debug'
) -> None:
    """
    Apply block swapping configuration to a DiT model with OOM protection.
    
    This is the main entry point for configuring block swapping on a model.
    Handles block selection, I/O component offloading, device placement, and
    forward method wrapping for dynamic memory management.
    
    Args:
        runner: VideoDiffusionInfer instance containing the model
        block_swap_config: Configuration dictionary with keys:
            - blocks_to_swap: Number of blocks to swap (from the start)
            - swap_io_components: Whether to offload I/O components  
            - enable_debug: Whether to enable debug logging
            - offload_device: Device to offload to (default: 'cpu')
        debug: Debug instance for logging (required)
    """
    # Early return if BlockSwap not enabled
    if not is_blockswap_enabled(block_swap_config):
        return

    blocks_to_swap = block_swap_config.get("blocks_to_swap", 0)
    swap_io_components = block_swap_config.get("swap_io_components", False)
    
    if debug is None:
        if hasattr(runner, 'debug') and runner.debug is not None:
            debug = runner.debug
        else:
            raise ValueError("Debug instance must be provided to apply_block_swap_to_dit")
    
    debug.start_timer("apply_blockswap")

    # Get the actual model (handle CompatibleDiT wrapper)
    model = runner.dit
    if hasattr(model, "dit_model"):
        model = model.dit_model
    
    # Determine devices
    if hasattr(runner, '_dit_device'):
        device = runner._dit_device
    else:
        device = get_device()
    offload_device = block_swap_config.get("offload_device", torch.device('cpu'))

    # Validate model structure
    if not hasattr(model, "blocks"):
        debug.log("Model doesn't have 'blocks' attribute for BlockSwap", level="ERROR", category="blockswap", force=True)
        return

    total_blocks = len(model.blocks)
    
    # Clamp blocks_to_swap to available blocks BEFORE logging
    effective_blocks = min(blocks_to_swap, total_blocks) if blocks_to_swap > 0 else 0
    
    # Log configuration clearly based on what's enabled
    block_text = "block" if effective_blocks <= 1 else "blocks"
    if effective_blocks > 0 and swap_io_components:
        debug.log(f"BlockSwap: {effective_blocks}/{total_blocks} transformer {block_text} + I/O components offloaded to {str(offload_device).upper()}", category="blockswap", force=True)
    elif effective_blocks > 0:
        debug.log(f"BlockSwap: {effective_blocks}/{total_blocks} transformer {block_text} offloaded to {str(offload_device).upper()}", category="blockswap", force=True)
    elif swap_io_components:
        debug.log(f"BlockSwap: I/O components offloaded to {str(offload_device).upper()} (0/{total_blocks} blocks swapped)", category="blockswap", force=True)
    
    # Configure model with blockswap attributes
    if blocks_to_swap > 0:
        model.blocks_to_swap = effective_blocks - 1  # Convert to 0-indexed
    else:
        # No block swapping, set to -1 so no blocks match the swap condition
        model.blocks_to_swap = -1
    
    model.main_device = device
    model.offload_device = offload_device

    # Configure I/O components
    io_config = _configure_io_components(model, device, offload_device, 
                                        swap_io_components, debug)
    memory_stats = _configure_blocks(model, device, offload_device, debug)
    memory_stats['io_components'] = io_config['components']
    memory_stats['io_memory_mb'] = io_config['memory_mb']
    memory_stats['gpu_components'] = io_config['gpu_components']
    memory_stats['io_gpu_memory_mb'] = io_config['gpu_memory_mb']

    # Log memory summary
    _log_memory_summary(memory_stats, offload_device, device, swap_io_components, 
                       debug)

    # Initialize Nunchaku-style async swap stream and prefetch-event bookkeeping
    if blocks_to_swap > 0:
        # normalize device objects
        if isinstance(device, str):
            device = torch.device(device)
        model._swap_stream = torch.cuda.Stream(device=device)
        model._block_ready_events = {}

        # Preload first swapped block to seed pipeline (non-blocking on swap_stream)
        try:
            first_idx = 0
            if first_idx <= model.blocks_to_swap:
                with torch.cuda.stream(model._swap_stream):
                    model.blocks[first_idx].to(device, non_blocking=True)
                    ev = torch.cuda.Event(blocking=False)
                    ev.record(model._swap_stream) # record on swap_stream -> event gets device-bound here
                    model._block_ready_events[first_idx] = ev
        except Exception as e:
            debug.log(f"Failed to initialize swap-stream prefetch: {e}", level="WARNING", category="blockswap", force=True)


    # Wrap block forward methods for dynamic swapping (only if blocks_to_swap > 0)
    if blocks_to_swap > 0:
        for b, block in enumerate(model.blocks):
            if b <= model.blocks_to_swap:
                _wrap_block_forward(block, b, model, debug)

    # Patch RoPE modules for robust error handling
    _patch_rope_for_blockswap(model, debug)

    # Mark BlockSwap as active
    runner._blockswap_active = True

    # Store configuration for debugging and cleanup
    model._block_swap_config = {
        "blocks_swapped": blocks_to_swap,
        "swap_io_components": swap_io_components,
        "total_blocks": total_blocks,
        "offload_device": offload_device,
        "main_device": device,
        "offload_memory": memory_stats['offload_memory'],
        "main_memory": memory_stats['main_memory']
    }

    # Protect model from being moved entirely
    _protect_model_from_move(model, runner, debug)

    debug.log("BlockSwap configuration complete", category="success")
    debug.end_timer("apply_blockswap", "BlockSwap configuration application")
    

def _configure_io_components(
    model: torch.nn.Module,
    device: torch.device,
    offload_device: torch.device,
    swap_io_components: bool,
    debug: 'Debug'
) -> Dict[str, Any]:
    """
    Configure I/O component placement and wrapping with memory tracking.
    
    Handles all non-block modules (embeddings, normalization layers, etc.) by
    either keeping them on GPU or offloading them with dynamic swapping wrappers.
    
    Args:
        model: DiT model containing named children to configure
        device: Main computation device (typically GPU)
        offload_device: Device for offloaded components (typically CPU)
        swap_io_components: If True, offload I/O components with dynamic swapping
        debug: Debug instance for logging (required)
        
    Returns:
        Dictionary containing:
            - components: List of offloaded component names
            - memory_mb: Total memory of offloaded components in MB
            - gpu_components: List of components remaining on GPU
            - gpu_memory_mb: Total memory of GPU components in MB
    """
    io_components_offloaded = []
    io_components_on_gpu = []
    io_memory_mb = 0.0
    io_gpu_memory_mb = 0.0
    
    # Check for pin memory condition
    use_pin_memory = (offload_device == "cpu") if isinstance(offload_device, str) else (offload_device.type == "cpu")

    # Handle I/O modules with dynamic swapping
    for name, module in model.named_children():
        if name != "blocks":
            module_memory = get_module_memory_mb(module)
            
            if swap_io_components:
                module.to(offload_device)
                
                # Enable Pin Memory for I/O components
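                # (pinned host memory is what allows the later non_blocking copies to run truly asynchronously)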
                if use_pin_memory:
                    for p in module.parameters():
                        if not p.is_pinned():
                            p.data = p.data.pin_memory()
                    for buf in module.buffers():
                        if not buf.is_pinned():
                            buf.data = buf.data.pin_memory()

                _wrap_io_forward(module, name, model, debug)
                io_components_offloaded.append(name)
                io_memory_mb += module_memory
                debug.log(f"{name}{str(offload_device).upper()} ({module_memory:.2f}MB, dynamic swapping)", category="blockswap", indent_level=1)
            else:
                module.to(device)
                io_components_on_gpu.append(name)
                io_gpu_memory_mb += module_memory
                debug.log(f"{name}{str(device).upper()} ({module_memory:.2f}MB)", category="blockswap", indent_level=1)

    return {
        'components': io_components_offloaded,
        'memory_mb': io_memory_mb,
        'gpu_components': io_components_on_gpu,
        'gpu_memory_mb': io_gpu_memory_mb
    }


def _configure_blocks(
    model: torch.nn.Module,
    device: torch.device,
    offload_device: torch.device,
    debug: 'Debug'
) -> Dict[str, float]:
    """
    Configure transformer block placement and calculate memory statistics.
    
    Moves blocks to their designated devices based on model.blocks_to_swap
    attribute. Blocks with index <= blocks_to_swap go to offload device,
    others stay on main device.
    
    Args:
        model: DiT model with blocks attribute and blocks_to_swap configured
        device: Main computation device for non-swapped blocks
        offload_device: Device for swapped blocks  
        debug: Debug instance for logging (required)
        
    Returns:
        Dictionary containing:
            - offload_memory: Total memory of offloaded blocks in MB
            - main_memory: Total memory of blocks on main device in MB
            - io_components: Empty list (populated by caller)
    """
    total_offload_memory = 0.0
    total_main_memory = 0.0
    
    # Check if we should pin memory (if offloading to CPU)
    # Nunchaku uses pinned memory for faster async transfers
    use_pin_memory = (offload_device == "cpu") if isinstance(offload_device, str) else (offload_device.type == "cpu")

    # Move blocks based on swap configuration
    for b, block in enumerate(model.blocks):
        block_memory = get_module_memory_mb(block)

        if b > model.blocks_to_swap:
            block.to(device)
            total_main_memory += block_memory
        else:
            block.to(offload_device, non_blocking=False)
            total_offload_memory += block_memory
            
            # Enable Pin Memory optimization for CPU Offload transfer speed
            if use_pin_memory:
                for p in block.parameters():
                    if not p.is_pinned():
                        p.data = p.data.pin_memory()
                for buf in block.buffers():
                    if not buf.is_pinned():
                        buf.data = buf.data.pin_memory()
                        
    # Ensure all buffers match their containing module's device
    for b, block in enumerate(model.blocks):
        target_device = torch.device(device if b > model.blocks_to_swap else offload_device)
        for name, buffer in block.named_buffers():
            if buffer.device != target_device:
                # Apply pinning if needed (normalize to torch.device so .type is always available)
                if use_pin_memory and target_device.type == "cpu" and not buffer.is_pinned():
                    buffer.data = buffer.data.pin_memory()
                buffer.data = buffer.data.to(target_device, non_blocking=False)

    return {
        "offload_memory": total_offload_memory,
        "main_memory": total_main_memory,
        "io_components": []  # Will be populated by caller
    }


def _log_memory_summary(
    memory_stats: Dict[str, float],
    offload_device: torch.device,
    device: torch.device,
    swap_io_components: bool,
    debug: 'Debug'
) -> None:
    """
    Log comprehensive memory usage summary for BlockSwap configuration.
    
    Displays detailed breakdown of memory distribution across devices,
    including transformer blocks and I/O components.
    
    Args:
        memory_stats: Dictionary containing:
            - offload_memory: Memory offloaded from blocks (MB)
            - main_memory: Memory remaining on main device (MB)
            - io_memory_mb: Memory from offloaded I/O components (MB)
            - io_gpu_memory_mb: Memory from I/O components on GPU (MB)
        offload_device: Device used for offloading
        device: Main computation device
        swap_io_components: Whether I/O components are being swapped
        debug: Debug instance for logging (required)
    """
    debug.log("BlockSwap memory configuration:", category="blockswap")
    
    # Log transformer blocks memory
    blocks_offloaded = memory_stats['offload_memory']
    blocks_on_gpu = memory_stats['main_memory']
    
    offload_str = str(offload_device)
    device_str = str(device)
    
    if blocks_on_gpu == 0:
        debug.log(f"Transformer blocks: {blocks_offloaded:.2f}MB on {offload_str} (dynamic swapping)", category="blockswap", indent_level=1)
    else:
        debug.log(f"Transformer blocks: {blocks_on_gpu:.2f}MB on {device_str}, {blocks_offloaded:.2f}MB on {offload_str}", category="blockswap", indent_level=1)
    
    # Always log I/O components (whether swapping or not)
    io_memory = memory_stats.get('io_memory_mb', 0.0)
    io_gpu_memory = memory_stats.get('io_gpu_memory_mb', 0.0)
    
    if swap_io_components and io_memory > 0:
        io_components = memory_stats.get('io_components', [])
        debug.log(f"I/O components: {io_memory:.2f}MB on {offload_str} (dynamic swapping)", category="blockswap", indent_level=1)
        debug.log(f"{', '.join(io_components)}", category="blockswap", indent_level=2)
    elif io_gpu_memory > 0:
        io_gpu_components = memory_stats.get('gpu_components', [])
        debug.log(f"I/O components: {io_gpu_memory:.2f}MB on {device_str}", category="blockswap", indent_level=1)
        debug.log(f"{', '.join(io_gpu_components)}", category="blockswap", indent_level=2)
    
    # Log total VRAM savings
    total_offloaded = blocks_offloaded + (io_memory if swap_io_components else 0)
    if total_offloaded > 0:
        debug.log(f"Total VRAM saved: {total_offloaded:.2f}MB (~{total_offloaded/1024:.2f}GB)", category="blockswap", indent_level=1)


def _wrap_block_forward(
    block: torch.nn.Module,
    block_idx: int,
    model: torch.nn.Module,
    debug: 'Debug'
) -> None:
    """
    Wrap individual transformer block forward for dynamic device swapping.

    Implements Nunchaku-style pipelining: Prefetch Next -> Compute Current -> Offload Current.
    https://github.com/nunchaku-tech/nunchaku/blob/main/nunchaku/models/utils.py
    
    Creates a wrapped forward method that automatically:
    1. Waits for (or performs) this block's transfer to the GPU
    2. Prefetches the next swapped block on a dedicated swap stream
    3. Executes the original forward pass, overlapped with the prefetch
    4. Asynchronously offloads this block back to the offload device
    5. Logs timing and manages memory pressure
    
    Uses weak references to prevent memory leaks from closure retention.
    
    Args:
        block: Individual transformer block to wrap
        block_idx: Index of this block in model.blocks
        model: Parent DiT model (used for device references)
        debug: Debug instance for logging (required)
    """
    if hasattr(block, '_original_forward'):
        return  # Already wrapped

    # Store original forward method
    original_forward = block.forward
    
    # Create weak references
    model_ref = weakref.ref(model)
    debug_ref = weakref.ref(debug) if debug is not None else (lambda: None)
    
    # Store block_idx on the block itself to avoid closure issues
    block._block_idx = block_idx
    
    def wrapped_forward(self, *args, **kwargs):
        # Retrieve weak references
        model = model_ref()
        debug = debug_ref()
        
        if not model:
            # Model has been garbage collected, fall back to original
            return original_forward(*args, **kwargs)

        # Check if block swap is active for this block
        if hasattr(model, 'blocks_to_swap') and self._block_idx <= model.blocks_to_swap:
            # Use dynamo-disabled helper to get start time (avoids compilation warnings)
            t_start = _get_swap_start_time(debug, debug.enabled if debug else False)
            
            # Only move to GPU if necessary
            current_device = next(self.parameters()).device
            target_device = torch.device(model.main_device)
            
            # 1. Ensure CURRENT block is ready on GPU
            # Check if we have a prefetch event waiting
            if hasattr(model, '_block_ready_events') and self._block_idx in model._block_ready_events:
                # Wait for the swap stream to finish moving this block
                torch.cuda.current_stream().wait_event(model._block_ready_events[self._block_idx])
                # Cleanup event
                del model._block_ready_events[self._block_idx]
            elif current_device != target_device:
                # Fallback: first block or missed prefetch, move synchronously (still non_blocking)
                if debug:
                    debug.log(f"Block {self._block_idx} missing prefetch event, moving synchronously", level="WARNING", category="blockswap", force=True)
                self.to(model.main_device, non_blocking=True)
            
            # 2. Trigger Prefetch for NEXT block (Pipelining)
            # Nunchaku logic: Start moving i+1 while i is computing
            next_idx = self._block_idx + 1
            if next_idx <= model.blocks_to_swap:
                next_block = model.blocks[next_idx]
                # Use the dedicated swap stream
                with torch.cuda.stream(model._swap_stream):
                    next_block.to(model.main_device, non_blocking=True)
                    # Record event so next iteration knows when to wait
                    event = torch.cuda.Event(blocking=False)
                    event.record(model._swap_stream)
                    model._block_ready_events[next_idx] = event

            # 3. Execute forward pass (Compute)
            # This runs on the default stream, overlapping with the prefetch above
            output = original_forward(*args, **kwargs)

            # 4. Offload CURRENT block (Async)
            # We record an event on compute stream to ensure we don't move data while it's being used
            compute_done_event = torch.cuda.Event(blocking=False)
            compute_done_event.record(torch.cuda.current_stream())
            
            with torch.cuda.stream(model._swap_stream):
                # Wait for compute to finish before moving memory out
                model._swap_stream.wait_event(compute_done_event)
                # Move back to offload device
                self.to(model.offload_device, non_blocking=True)

            # Use dynamo-disabled helper to log timing (avoids compilation warnings)
            _log_swap_timing(debug, t_start, self._block_idx, "block (pipelined)")

            # Only clear cache under memory pressure
            clear_memory(debug=debug, deep=False, force=False, timer_name="wrap_block_forward")
        else:
            output = original_forward(*args, **kwargs)

        return output
    
    # Bind the wrapped function as a method to the block
    block.forward = types.MethodType(wrapped_forward, block)
    
    # Store reference to original forward for cleanup
    block._original_forward = original_forward


def _wrap_io_forward(
    module: torch.nn.Module,
    module_name: str,
    model: torch.nn.Module,
    debug: 'Debug'
) -> None:
    """
    Wrap I/O component forward for dynamic device swapping.
    
    Similar to _wrap_block_forward but for I/O components (embeddings,
    normalization layers, etc.). Handles swapping between GPU and CPU
    during forward passes.
    
    Uses weak references to prevent circular dependencies and memory leaks.
    
    Args:
        module: I/O component module to wrap
        module_name: Name identifier for logging (e.g., 'x_embedder')
        model: Parent DiT model (used for device references)
        debug: Debug instance for logging (required)
    """
    if hasattr(module, '_is_io_wrapped') and module._is_io_wrapped:
        debug.log(f"Reusing existing I/O wrapper for {module_name}", category="reuse")
        return  # Already wrapped

    # Store original forward method
    original_forward = module.forward
    
    # Create weak references
    model_ref = weakref.ref(model)
    debug_ref = weakref.ref(debug) if debug else lambda: None
    
    # Store module name on the module itself
    module._module_name = module_name
    module._original_forward = original_forward
    
    def wrapped_io_forward(self, *args, **kwargs):
        # Retrieve weak references
        model = model_ref()
        debug = debug_ref()
        
        if not model:
            # Model has been garbage collected, fall back to original
            return self._original_forward(*args, **kwargs)

        # Use dynamo-disabled helper to get start time (avoids compilation warnings)
        t_start = _get_swap_start_time(debug, debug.enabled if debug else False)
        
        # Check current device to avoid unnecessary moves
        # (treat parameterless modules as needing a move rather than raising StopIteration)
        try:
            current_device = next(self.parameters()).device
        except StopIteration:
            current_device = None
        target_device = torch.device(model.main_device)
        
        # Move to GPU for computation if needed
        if current_device != target_device:
            self.to(model.main_device, non_blocking=False)

        # Execute forward pass
        output = self._original_forward(*args, **kwargs)

        # Move back to offload device
        self.to(model.offload_device, non_blocking=False)
        
        # Use dynamo-disabled helper to log timing (avoids compilation warnings)
        _log_swap_timing(debug, t_start, self._module_name, "I/O")

        # Only clear cache under memory pressure
        clear_memory(debug=debug, deep=False, force=False, timer_name="wrap_io_forward")

        return output
    
    # Bind as a method
    module.forward = types.MethodType(wrapped_io_forward, module)
    module._is_io_wrapped = True
    
    # Store module reference for restoration
    if not hasattr(model, '_io_swappers'):
        model._io_swappers = []
    model._io_swappers.append((module, module_name))


def _patch_rope_for_blockswap(
    model: torch.nn.Module,
    debug: 'Debug'
) -> None:
    """
    Patch RoPE (Rotary Position Embedding) modules for device-aware fallback.
    
    Adds CPU fallback logic to RoPE modules to handle device mismatch errors
    that can occur during BlockSwap operations. Complements the stability
    wrapper from compatibility.py with device-specific error handling.
    
    Args:
        model: DiT model containing RoPE modules to patch
        debug: Debug instance for logging (required)
    """
    rope_patches = []
    
    for name, module in model.named_modules():
        if "rope" in name.lower() and hasattr(module, "get_axial_freqs"):
            # Skip if already wrapped by blockswap
            if hasattr(module, '_blockswap_wrapped') and module._blockswap_wrapped:
                continue
            
            # Get current method (might be stability-wrapped)
            current_method = module.get_axial_freqs
            
            # Create device-aware wrapper with proper closure handling
            def make_device_aware_wrapper(module_name, current_fn):
                def device_aware_rope_wrapper(self, *args, **kwargs):
                    try:
                        # Try current method (original or stability-wrapped)
                        return current_fn(*args, **kwargs)
                    except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
                        error_msg = str(e).lower()
                        # Only handle device/memory specific errors
                        if any(x in error_msg for x in ["device", "memory", "allocation"]):
                            debug.log(f"RoPE OOM for {module_name}", level="WARNING", category="rope", force=True)
                            debug.log(f"Clearing RoPE cache and retrying", category="info", force=True)
                            
                            # Get current device from parameters
                            try:
                                current_device = next(self.parameters()).device
                            except StopIteration:
                                # Fallback: use model's main_device if BlockSwap has set it,
                                # else offload_device, else CPU
                                if hasattr(model, 'main_device'):
                                    current_device = torch.device(model.main_device)
                                elif hasattr(model, 'offload_device'):
                                    current_device = torch.device(model.offload_device)
                                else:
                                    current_device = torch.device('cpu')
                            
                            # Try clearing cache first (non-invasive fix)
                            if hasattr(current_fn, 'cache_clear'):
                                current_fn.cache_clear()
                                try:
                                    # Retry on same device after clearing cache
                                    return current_fn(*args, **kwargs)
                                except Exception as retry_error:
                                    # Cache clear wasn't enough, need more drastic measures
                                    debug.log(f"Cache clear insufficient for {module_name}, falling back to CPU", level="WARNING", category="rope", force=True)
                            
                            # Fallback to CPU computation with stability
                            self.cpu()
                            
                            try:
                                # Use call_rope_with_stability for CPU computation
                                # This ensures cache is cleared and autocast disabled
                                original_fn = getattr(self, '_original_get_axial_freqs', current_fn)
                                result = call_rope_with_stability(original_fn, *args, **kwargs)
                                
                                # Move module back to original device
                                self.to(current_device)
                                
                                # Move result to appropriate device if it's a tensor
                                if hasattr(result, 'to'):
                                    target_device = args[0].device if len(args) > 0 and hasattr(args[0], 'device') else current_device
                                    return result.to(target_device)
                                return result
                                
                            except Exception as cpu_error:
                                # Always restore device even on error
                                self.to(current_device)
                                raise cpu_error
                        else:
                            # Not a device error, let it bubble up
                            raise
                
                return device_aware_rope_wrapper
            
            # Apply wrapper
            module.get_axial_freqs = types.MethodType(
                make_device_aware_wrapper(name, current_method), 
                module
            )
            module._blockswap_wrapped = True
            
            # Store for cleanup (use original or previously stored)
            original_method = getattr(module, '_original_get_axial_freqs', current_method)
            rope_patches.append((module, original_method))
    
    if rope_patches:
        model._rope_patches = rope_patches
        debug.log(f"Patched {len(rope_patches)} RoPE modules with device handling", category="success")


def _protect_model_from_move(
    model: torch.nn.Module,
    runner: 'VideoDiffusionInfer',
    debug: 'Debug'
) -> None:
    """
    Protect model from unintended full device movement during BlockSwap.
    
    Wraps model.to() method to prevent other code from accidentally moving
    the entire model to GPU, which would defeat BlockSwap's memory savings.
    Allows movement only when explicitly bypassed via model flag.
    
    Args:
        model: DiT model to protect
        runner: VideoDiffusionInfer instance (for active status check)
        debug: Debug instance for logging (required)
    """
    if not hasattr(model, '_original_to'):
        # Store runner reference as weak reference to avoid circular refs
        model._blockswap_runner_ref = weakref.ref(runner)
        model._original_to = model.to
        
        # Define the protected method without closures
        def protected_model_to(self, device, *args, **kwargs):
            # Check if protection is temporarily bypassed for offloading
            # Flag is stored on model itself (not runner) to survive runner recreation
            if getattr(self, "_blockswap_bypass_protection", False):
                # Protection bypassed, allow movement
                if hasattr(self, '_original_to'):
                    return self._original_to(device, *args, **kwargs)
            
            # Get configured offload device directly from model
            blockswap_offload_device = "cpu"  # default
            if hasattr(self, "_block_swap_config"):
                blockswap_offload_device = self._block_swap_config.get("offload_device", "cpu")
            
            # Check if BlockSwap is currently active via runner weak reference
            runner_ref = getattr(self, '_blockswap_runner_ref', None)
            blockswap_is_active = False
            if runner_ref:
                runner_obj = runner_ref()
                if runner_obj and hasattr(runner_obj, "_blockswap_active"):
                    blockswap_is_active = runner_obj._blockswap_active
            
            # Block attempts to move model away from configured offload device when active
            if blockswap_is_active and str(device) != str(blockswap_offload_device):
                # Get debug instance from runner if available
                debug_instance = None
                if runner_ref:
                    runner_obj = runner_ref()
                    if runner_obj and hasattr(runner_obj, 'debug'):
                        debug_instance = runner_obj.debug
                
                if debug_instance:
                    debug_instance.log(
                        f"Blocked attempt to move BlockSwap model from {blockswap_offload_device} to {device}",
                        level="WARNING", category="blockswap", force=True
                    )
                return self
            
            # Allow movement (either bypass is enabled or target is offload device)
            if hasattr(self, '_original_to'):
                return self._original_to(device, *args, **kwargs)
            else:
                # Fallback - shouldn't happen
                return super(type(self), self).to(device, *args, **kwargs)
        
        # Bind as a method to the model instance
        model.to = types.MethodType(protected_model_to, model)


def set_blockswap_bypass(runner, bypass: bool, debug):
    """
    Set or unset bypass flag for BlockSwap protection.
    Used for offloading to temporarily allow model movement.
    
    Args:
        runner: Runner instance with BlockSwap
        bypass: True to bypass protection, False to enforce it
        debug: Debug instance for logging
    """
    if not hasattr(runner, "_blockswap_active") or not runner._blockswap_active:
        return
    
    # Get the actual model (handle CompatibleDiT wrapper)
    model = runner.dit
    if hasattr(model, "dit_model"):
        model = model.dit_model
    
    # Store on model so it survives runner recreation during caching
    model._blockswap_bypass_protection = bypass
    
    if bypass:
        debug.log("BlockSwap protection disabled to allow model DiT offloading", category="success")
    else:
        debug.log("BlockSwap protection renabled to avoid accidentally offloading the entire DiT model", category="success")


def cleanup_blockswap(runner, keep_state_for_cache=False):
    """
    Clean up BlockSwap configuration based on caching mode.
    
    When caching (keep_state_for_cache=True):
    - Keep all BlockSwap configuration intact
    - Only mark as inactive for safety during non-inference operations
    
    When not caching (keep_state_for_cache=False):
    - Full cleanup of all BlockSwap state
    
    Args:
        runner: VideoDiffusionInfer instance to clean up
        keep_state_for_cache: If True, preserve BlockSwap state for reuse
    """
    # Get debug instance from runner
    if not hasattr(runner, 'debug') or runner.debug is None:
        raise ValueError("Debug instance must be available on runner for cleanup_blockswap")
    
    debug = runner.debug
    
    # Get the actual model (handle CompatibleDiT wrapper)
    model = runner.dit
    if hasattr(model, "dit_model"):
        model = model.dit_model
    
    # Check if there's any BlockSwap state to clean up (check both runner and model)
    has_blockswap_state = (
        hasattr(runner, "_blockswap_active") or 
        hasattr(model, "_block_swap_config") or
        hasattr(model, "_blockswap_bypass_protection")
    )
    
    if not has_blockswap_state:
        return

    debug.log("Starting BlockSwap cleanup", category="cleanup")

    if keep_state_for_cache:
        # Minimal cleanup for caching - just mark as inactive and allow offloading
        # Everything else stays intact for fast reactivation
        if hasattr(runner, "_blockswap_active") and runner._blockswap_active:
            if not getattr(model, "_blockswap_bypass_protection", False):
                set_blockswap_bypass(runner=runner, bypass=True, debug=debug)
            runner._blockswap_active = False
        debug.log("BlockSwap deactivated for caching (configuration preserved)", category="success")
        return

    # Full cleanup when not caching

    # 1. Restore block forward methods
    if hasattr(model, 'blocks'):
        restored_count = 0
        for block in model.blocks:
            if hasattr(block, '_original_forward'):
                block.forward = block._original_forward
                delattr(block, '_original_forward')
                restored_count += 1
                
                # Clean up wrapper attributes
                for attr in ['_block_idx', '_model_ref', '_debug_ref', '_blockswap_wrapped']:
                    if hasattr(block, attr):
                        delattr(block, attr)
        
        if restored_count > 0:
            debug.log(f"Restored {restored_count} block forward methods", category="success")

    # 2. Restore RoPE patches
    if hasattr(model, '_rope_patches'):
        for module, original_method in model._rope_patches:
            module.get_axial_freqs = original_method
            # Clean up wrapper attributes
            for attr in ['_rope_wrapped', '_original_get_axial_freqs']:
                if hasattr(module, attr):
                    delattr(module, attr)
        debug.log(f"Restored {len(model._rope_patches)} RoPE methods", category="success")
        delattr(model, '_rope_patches')

    # 3. Restore I/O component forward methods and move to offload device
    if hasattr(model, '_io_swappers'):
        for module, module_name in model._io_swappers:
            if hasattr(module, '_original_forward'):
                module.forward = module._original_forward
                # Clean up wrapper attributes
                for attr in ['_original_forward', '_model_ref', '_debug_ref', 
                           '_module_name', '_is_io_wrapped']:
                    if hasattr(module, attr):
                        delattr(module, attr)
        debug.log(f"Restored {len(model._io_swappers)} I/O components", category="success")
        delattr(model, '_io_swappers')
    
    # Move all IO components to offload device during full cleanup
    if hasattr(model, 'offload_device'):
        offload_device = model.offload_device
        moved_count = 0
        for name, module in model.named_children():
            if name != "blocks":
                module.to(offload_device)
                moved_count += 1
        if moved_count > 0:
            debug.log(f"Moved {moved_count} IO components to offload device", category="success")

    # 4. Restore original .to() method
    if hasattr(model, '_original_to'):
        model.to = model._original_to
        delattr(model, '_original_to')
        debug.log("Restored original .to() method", category="success")

    # 5. Clean up BlockSwap-specific attributes
    for attr in ['_blockswap_runner_ref', 'blocks_to_swap', 'main_device', 
                 'offload_device']:
        if hasattr(model, attr):
            delattr(model, attr)

    # 6. Clean up runner attributes
    runner._blockswap_active = False
    
    # Clean up pipelining resources on model (synchronize first)
    if hasattr(model, '_swap_stream'):
        try:
            model._swap_stream.synchronize()
        except Exception:
            pass
    for attr in ['_swap_stream', '_block_ready_events']:
        if hasattr(model, attr):
            delattr(model, attr)

    # Remove all config attributes
    for attr in ['_cached_blockswap_config', '_block_swap_config', '_blockswap_debug']:
        if hasattr(runner, attr):
            delattr(runner, attr)
    
    debug.log("BlockSwap cleanup complete", category="success")