| from ..utils import log | |
| import torch | |
def set_transformer_cache_method(transformer, timesteps, cache_args=None):
    """Configure one of the step-caching strategies (TeaCache / MagCache / EasyCache)
    on ``transformer`` and return it.

    Args:
        transformer: model object; cache flags/thresholds are set as attributes,
            and its ``*_state`` holders are expected to expose ``clear_all()``.
        timesteps: sampling schedule; only its length is used, to resolve an
            ``end_step`` of -1 to the final step index.
        cache_args: dict describing the cache configuration (keys: ``cache_type``,
            ``cache_device``, ``start_step``, ``end_step``, plus per-type keys).
            ``None`` disables configuration entirely.

    Returns:
        The same ``transformer`` instance, for chaining.
    """
    # Fix: the original body subscripted cache_args unconditionally, so calling
    # with the declared default (cache_args=None) raised TypeError.
    if cache_args is None:
        return transformer

    transformer.cache_device = cache_args["cache_device"]
    cache_type = cache_args["cache_type"]
    # end_step of -1 means "cache through the final timestep"; shared by all types.
    end_step = len(timesteps) - 1 if cache_args["end_step"] == -1 else cache_args["end_step"]

    if cache_type == "TeaCache":
        log.info(f"TeaCache: Using cache device: {transformer.cache_device}")
        transformer.teacache_state.clear_all()
        transformer.enable_teacache = True
        transformer.rel_l1_thresh = cache_args["rel_l1_thresh"]
        transformer.teacache_start_step = cache_args["start_step"]
        transformer.teacache_end_step = end_step
        transformer.teacache_use_coefficients = cache_args["use_coefficients"]
        transformer.teacache_mode = cache_args["mode"]
    elif cache_type == "MagCache":
        log.info(f"MagCache: Using cache device: {transformer.cache_device}")
        transformer.magcache_state.clear_all()
        transformer.enable_magcache = True
        transformer.magcache_start_step = cache_args["start_step"]
        transformer.magcache_end_step = end_step
        transformer.magcache_thresh = cache_args["magcache_thresh"]
        transformer.magcache_K = cache_args["magcache_K"]
    elif cache_type == "EasyCache":
        log.info(f"EasyCache: Using cache device: {transformer.cache_device}")
        transformer.easycache_state.clear_all()
        transformer.enable_easycache = True
        transformer.easycache_start_step = cache_args["start_step"]
        transformer.easycache_end_step = end_step
        transformer.easycache_thresh = cache_args["easycache_thresh"]
    return transformer
class TeaCacheState:
    """Per-prediction bookkeeping for TeaCache: cached residuals, accumulated
    relative-L1 distance, the previous modulated input, and which steps were skipped.
    Each concurrent prediction (e.g. cond/uncond pass) gets its own integer ID."""

    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Allocate a fresh state slot and return its integer ID."""
        self.cache_device = cache_device
        assigned = self._next_pred_id
        self._next_pred_id = assigned + 1
        fresh_slot = {
            'previous_residual': None,
            'accumulated_rel_l1_distance': 0,
            'previous_modulated_input': None,
            'skipped_steps': [],
        }
        self.states[assigned] = fresh_slot
        return assigned

    def update(self, pred_id, **kwargs):
        """Merge keyword fields into an existing prediction's state; no-op for unknown IDs."""
        slot = self.states.get(pred_id)
        if slot is None:
            return None
        slot.update(kwargs)

    def get(self, pred_id):
        """Return the state dict for pred_id, or an empty dict if it doesn't exist."""
        return self.states.get(pred_id, {})

    def clear_all(self):
        """Drop every prediction's state and restart ID numbering from zero."""
        self.states = {}
        self._next_pred_id = 0
class MagCacheState:
    """Per-prediction bookkeeping for MagCache: the cached residual, the running
    magnitude ratio / step count / error accumulators, and skipped-step history.
    IDs are handed out sequentially, one per concurrent prediction."""

    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Register a new prediction slot and hand back its ID."""
        self.cache_device = cache_device
        new_id, self._next_pred_id = self._next_pred_id, self._next_pred_id + 1
        self.states[new_id] = dict(
            residual_cache=None,
            accumulated_ratio=1.0,
            accumulated_steps=0,
            accumulated_err=0,
            skipped_steps=[],
        )
        return new_id

    def update(self, pred_id, **kwargs):
        """Write the given fields into pred_id's slot; unknown IDs are ignored."""
        if pred_id not in self.states:
            return None
        self.states[pred_id].update(kwargs)

    def get(self, pred_id):
        """Fetch pred_id's state dict ({} when absent)."""
        return self.states.get(pred_id, {})

    def clear_all(self):
        """Forget all prediction state and reset the ID counter."""
        self.states = {}
        self._next_pred_id = 0
class EasyCacheState:
    """Per-prediction bookkeeping for EasyCache: previous raw input/output, the
    cached value, accumulated error, and skipped-step history. One sequential
    integer ID per concurrent prediction."""

    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Open a new prediction slot and return its ID."""
        self.cache_device = cache_device
        pid = self._next_pred_id
        self._next_pred_id += 1
        empty_slot = {
            'previous_raw_input': None,
            'previous_raw_output': None,
            'cache': None,
            'accumulated_error': 0.0,
            'skipped_steps': [],
        }
        self.states[pid] = empty_slot
        return pid

    def update(self, pred_id, **kwargs):
        """Set the given fields on an existing slot; silently skip unknown IDs."""
        target = self.states.get(pred_id)
        if target is None:
            return None
        for field, value in kwargs.items():
            target[field] = value

    def get(self, pred_id):
        """Return the slot for pred_id, or {} when it was never created."""
        return self.states.get(pred_id, {})

    def clear_all(self):
        """Wipe every slot and restart IDs at zero."""
        self.states = {}
        self._next_pred_id = 0
def relative_l1_distance(last_tensor, current_tensor):
    """Mean absolute difference between the two tensors, normalized by the mean
    absolute magnitude of ``last_tensor``; returned as float32 on
    ``current_tensor``'s device."""
    previous = last_tensor.to(current_tensor.device)
    mean_abs_diff = (previous - current_tensor).abs().mean()
    scale = last_tensor.abs().mean()
    ratio = mean_abs_diff / scale
    return ratio.to(torch.float32).to(current_tensor.device)
def cache_report(transformer, cache_args):
    """Log how many steps each prediction skipped under the active cache, then
    reset every cache's state on ``transformer``.

    Args:
        transformer: model object exposing ``teacache_state``, ``magcache_state``
            and ``easycache_state`` holders (each with ``.states`` and ``clear_all()``).
        cache_args: dict whose ``cache_type`` selects which cache to report on.

    Returns:
        None.
    """
    cache_type = cache_args["cache_type"]
    states_by_type = {
        "TeaCache": transformer.teacache_state.states,
        "MagCache": transformer.magcache_state.states,
        "EasyCache": transformer.easycache_state.states,
    }
    states = states_by_type.get(cache_type)
    # Prediction IDs 0/1 conventionally correspond to the cond/uncond passes.
    state_names = {
        0: "conditional",
        1: "unconditional"
    }
    # Fix: the original iterated `states.items()` unconditionally, so an unknown
    # cache_type crashed on None before the clear_all() resets below could run.
    if states is not None:
        for pred_id, state in states.items():
            name = state_names.get(pred_id, f"prediction_{pred_id}")
            if 'skipped_steps' in state:
                log.info(f"{cache_type} skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}")
    # Always reset all three caches, whichever one was active.
    transformer.teacache_state.clear_all()
    transformer.magcache_state.clear_all()
    transformer.easycache_state.clear_all()