File size: 3,909 Bytes
dc155d4
 
 
 
 
 
 
 
 
 
 
 
879ee4e
dc155d4
 
 
7019232
dc155d4
 
 
 
db19b41
dc155d4
db19b41
 
dc155d4
 
 
db19b41
 
 
dc155d4
 
 
 
 
 
 
 
 
 
 
 
6338de3
 
 
 
 
 
434113d
dc155d4
 
 
 
 
879ee4e
db19b41
e5acb29
bb98bd0
 
 
e5acb29
b5c3f40
 
e5acb29
bb98bd0
 
 
e5acb29
bb98bd0
34767f8
bb98bd0
e5acb29
 
dc155d4
 
 
 
 
 
 
6ff4937
dc155d4
db19b41
 
dc155d4
 
db19b41
dc155d4
 
 
db19b41
dc155d4
 
db19b41
dc155d4
 
 
db19b41
 
 
 
dc155d4
 
82d7cc1
db19b41
 
 
 
55e04d3
dc155d4
db19b41
55e04d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
"""

from typing import Any
from typing import Callable
from typing import ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig

from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import drain_module_parameters


P = ParamSpec('P')

LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)

LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)

TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        2: LATENT_FRAMES_DIM,
        3: 2 * LATENT_PATCHED_HEIGHT_DIM,
        4: 2 * LATENT_PATCHED_WIDTH_DIM,
    },
}

INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}

# Add before optimize_pipeline_ call
import torch._inductor.config as inductor_config

# Temporarily reduce optimization aggressiveness for startup
original_max_autotune = inductor_config.max_autotune
inductor_config.max_autotune = False
inductor_config.max_autotune_gemm = False

def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):

    @spaces.GPU(duration=1500)
    def compile_transformer():
        
        # This LoRA fusion part remains the same
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy", 
            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors", 
            adapter_name="lightx2v"
        )
        kwargs_lora = {}
        kwargs_lora["load_into_transformer_2"] = True
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy", 
            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors", 
            adapter_name="lightx2v_2", **kwargs_lora
        )
        pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
        pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
        pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
        pipeline.unload_lora_weights()
        
        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)
        
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
        
        
        exported_1 = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )
        
        exported_2 = torch.export.export(
            mod=pipeline.transformer_2,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )

        compiled_1 = aoti_compile(exported_1, INDUCTOR_CONFIGS)
        compiled_2 = aoti_compile(exported_2, INDUCTOR_CONFIGS)
        
        return compiled_1, compiled_2


    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
    
    compiled_transformer_1, compiled_transformer_2 = compile_transformer()

    pipeline.transformer.forward = compiled_transformer_1
    drain_module_parameters(pipeline.transformer)

    pipeline.transformer_2.forward = compiled_transformer_2
    drain_module_parameters(pipeline.transformer_2)