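"""Step-skipping cache helpers for diffusion transformer sampling.

Supports three schemes: TeaCache, MagCache, and EasyCache. Each keeps
per-prediction state (e.g. conditional / unconditional) and tracks which
sampling steps were skipped and served from cache.
"""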
from ..utils import log
import torch

def set_transformer_cache_method(transformer, timesteps, cache_args=None):
    """Enable one of the step-skipping cache schemes (TeaCache, MagCache, or
    EasyCache) on ``transformer`` according to ``cache_args`` and return it."""
    if cache_args is None:
        return transformer
    transformer.cache_device = cache_args["cache_device"]
    if cache_args["cache_type"] == "TeaCache":
        log.info(f"TeaCache: Using cache device: {transformer.cache_device}")
        transformer.teacache_state.clear_all()
        transformer.enable_teacache = True
        transformer.rel_l1_thresh = cache_args["rel_l1_thresh"]
        transformer.teacache_start_step = cache_args["start_step"]
        transformer.teacache_end_step = len(timesteps)-1 if cache_args["end_step"] == -1 else cache_args["end_step"]
        transformer.teacache_use_coefficients = cache_args["use_coefficients"]
        transformer.teacache_mode = cache_args["mode"]
    elif cache_args["cache_type"] == "MagCache":
        log.info(f"MagCache: Using cache device: {transformer.cache_device}")
        transformer.magcache_state.clear_all()
        transformer.enable_magcache = True
        transformer.magcache_start_step = cache_args["start_step"]
        transformer.magcache_end_step = len(timesteps)-1 if cache_args["end_step"] == -1 else cache_args["end_step"]
        transformer.magcache_thresh = cache_args["magcache_thresh"]
        transformer.magcache_K = cache_args["magcache_K"]
    elif cache_args["cache_type"] == "EasyCache":
        log.info(f"EasyCache: Using cache device: {transformer.cache_device}")
        transformer.easycache_state.clear_all()
        transformer.enable_easycache = True
        transformer.easycache_start_step = cache_args["start_step"]
        transformer.easycache_end_step = len(timesteps)-1 if cache_args["end_step"] == -1 else cache_args["end_step"]
        transformer.easycache_thresh = cache_args["easycache_thresh"]
    return transformer
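
# Usage sketch (illustrative values; the keys are the ones this function reads):
#
#   cache_args = {
#       "cache_type": "TeaCache",
#       "cache_device": "cpu",
#       "rel_l1_thresh": 0.15,         # illustrative threshold
#       "start_step": 5,
#       "end_step": -1,                # -1 means "run through the last timestep"
#       "use_coefficients": True,
#       "mode": "e",                   # illustrative mode string
#   }
#   transformer = set_transformer_cache_method(transformer, timesteps, cache_args)

# The three *CacheState classes below share the same interface
# (new_prediction / update / get / clear_all) and differ only in the
# per-prediction fields they track.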

class TeaCacheState:
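    """Per-prediction TeaCache state: the previous residual and modulated input,
    plus the accumulated relative-L1 distance that drives the skip decision."""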
    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0
    
    def new_prediction(self, cache_device='cpu'):
        """Create new prediction state and return its ID"""
        self.cache_device = cache_device
        pred_id = self._next_pred_id
        self._next_pred_id += 1
        self.states[pred_id] = {
            'previous_residual': None,
            'accumulated_rel_l1_distance': 0,
            'previous_modulated_input': None,
            'skipped_steps': [],
        }
        return pred_id
    
    def update(self, pred_id, **kwargs):
        """Update state for specific prediction"""
        if pred_id not in self.states:
            return None
        for key, value in kwargs.items():
            self.states[pred_id][key] = value
    
    def get(self, pred_id):
        return self.states.get(pred_id, {})
    
    def clear_all(self):
        self.states = {}
        self._next_pred_id = 0

class MagCacheState:
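    """Per-prediction MagCache state: the cached residual plus the accumulated
    magnitude ratio, step count, and error that drive the skip decision."""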
    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0
    
    def new_prediction(self, cache_device='cpu'):
        """Create new prediction state and return its ID"""
        self.cache_device = cache_device
        pred_id = self._next_pred_id
        self._next_pred_id += 1
        self.states[pred_id] = {
            'residual_cache': None,
            'accumulated_ratio': 1.0,
            'accumulated_steps': 0,
            'accumulated_err': 0,
            'skipped_steps': [],
        }
        return pred_id
    
    def update(self, pred_id, **kwargs):
        """Update state for specific prediction"""
        if pred_id not in self.states:
            return None
        for key, value in kwargs.items():
            self.states[pred_id][key] = value
    
    def get(self, pred_id):
        return self.states.get(pred_id, {})
    
    def clear_all(self):
        self.states = {}
        self._next_pred_id = 0

class EasyCacheState:
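    """Per-prediction EasyCache state: the previous raw input/output, the cache
    entry, and the accumulated error that drives the skip decision."""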
    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Create a new prediction state and return its ID."""
        self.cache_device = cache_device
        pred_id = self._next_pred_id
        self._next_pred_id += 1
        self.states[pred_id] = {
            'previous_raw_input': None,
            'previous_raw_output': None,
            'cache': None,
            'accumulated_error': 0.0,
            'skipped_steps': [],
        }
        return pred_id

    def update(self, pred_id, **kwargs):
        """Update state for a specific prediction."""
        if pred_id not in self.states:
            return None
        for key, value in kwargs.items():
            self.states[pred_id][key] = value

    def get(self, pred_id):
        return self.states.get(pred_id, {})

    def clear_all(self):
        self.states = {}
        self._next_pred_id = 0

def relative_l1_distance(last_tensor, current_tensor):
    """mean(|last - current|) / mean(|last|), as float32 on current_tensor's device."""
    l1_distance = torch.abs(last_tensor.to(current_tensor.device) - current_tensor).mean()
    norm = torch.abs(last_tensor).mean()
    rel_distance = l1_distance / norm
    return rel_distance.to(torch.float32).to(current_tensor.device)
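
# Skip-decision sketch (illustrative only, not the transformer's actual forward
# pass; ``prev_modulated``, ``cur_modulated`` and ``step_idx`` are hypothetical
# names). TeaCache accumulates the relative change between steps and serves a
# step from cache while the running total stays under the configured threshold:
#
#   dist = relative_l1_distance(prev_modulated, cur_modulated)
#   state['accumulated_rel_l1_distance'] += dist.item()
#   if state['accumulated_rel_l1_distance'] < transformer.rel_l1_thresh:
#       state['skipped_steps'].append(step_idx)      # reuse the cached residual
#   else:
#       state['accumulated_rel_l1_distance'] = 0     # compute the step in full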

def cache_report(transformer, cache_args):
    cache_type = cache_args["cache_type"]
    states = (
        transformer.teacache_state.states if cache_type == "TeaCache" else
        transformer.magcache_state.states if cache_type == "MagCache" else
        transformer.easycache_state.states if cache_type == "EasyCache" else
        None
    )
    if states is None:
        return
    state_names = {
        0: "conditional",
        1: "unconditional"
    }
    for pred_id, state in states.items():
        name = state_names.get(pred_id, f"prediction_{pred_id}")
        if 'skipped_steps' in state:
            log.info(f"{cache_type} skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}")
    transformer.teacache_state.clear_all()
    transformer.magcache_state.clear_all()
    transformer.easycache_state.clear_all()
    del states