# Mirror (by aliensmn) of https://github.com/kijai/ComfyUI-WanVideoWrapper
# Mirror revision: cf812a0 (verified)
from ..utils import log
import torch
def set_transformer_cache_method(transformer, timesteps, cache_args=None):
    """Configure ``transformer`` for one of the supported step-caching schemes.

    Resets the per-prediction state of the selected cache and copies the
    relevant settings from ``cache_args`` onto ``transformer`` as attributes.
    Only the selected cache's attributes are touched; flags set by earlier
    calls for a different cache type are left as they are.

    Args:
        transformer: Model object exposing ``teacache_state``,
            ``magcache_state`` or ``easycache_state`` (each with a
            ``clear_all()`` method), depending on the selected cache type.
        timesteps: Sized sequence of sampling timesteps. Only its length is
            used, to resolve an ``end_step`` of -1 to the last step index.
        cache_args: Dict of cache settings. It must contain ``"cache_device"``
            and ``"cache_type"`` (``"TeaCache"``, ``"MagCache"`` or
            ``"EasyCache"``); the remaining required keys depend on the type.
            Any other ``cache_type`` only records ``cache_device``. When
            ``None`` (the default), nothing is configured.

    Returns:
        The same ``transformer`` object, mutated in place.
    """
    if cache_args is None:
        # Without cache settings there is nothing to configure; indexing None
        # below would raise TypeError.
        return transformer

    def resolve_end_step(end_step):
        # -1 is a sentinel meaning "through the last timestep index".
        return len(timesteps) - 1 if end_step == -1 else end_step

    transformer.cache_device = cache_args["cache_device"]
    cache_type = cache_args["cache_type"]
    if cache_type == "TeaCache":
        log.info(f"TeaCache: Using cache device: {transformer.cache_device}")
        transformer.teacache_state.clear_all()
        transformer.enable_teacache = True
        transformer.rel_l1_thresh = cache_args["rel_l1_thresh"]
        transformer.teacache_start_step = cache_args["start_step"]
        transformer.teacache_end_step = resolve_end_step(cache_args["end_step"])
        transformer.teacache_use_coefficients = cache_args["use_coefficients"]
        transformer.teacache_mode = cache_args["mode"]
    elif cache_type == "MagCache":
        log.info(f"MagCache: Using cache device: {transformer.cache_device}")
        transformer.magcache_state.clear_all()
        transformer.enable_magcache = True
        transformer.magcache_start_step = cache_args["start_step"]
        transformer.magcache_end_step = resolve_end_step(cache_args["end_step"])
        transformer.magcache_thresh = cache_args["magcache_thresh"]
        transformer.magcache_K = cache_args["magcache_K"]
    elif cache_type == "EasyCache":
        log.info(f"EasyCache: Using cache device: {transformer.cache_device}")
        transformer.easycache_state.clear_all()
        transformer.enable_easycache = True
        transformer.easycache_start_step = cache_args["start_step"]
        transformer.easycache_end_step = resolve_end_step(cache_args["end_step"])
        transformer.easycache_thresh = cache_args["easycache_thresh"]
    return transformer
class TeaCacheState:
    """Per-prediction bookkeeping for TeaCache.

    Every ``new_prediction`` call allocates a fresh state dict under an
    integer id (0, 1, 2, ... in allocation order). ``clear_all`` discards all
    states and restarts the numbering at 0.
    """

    def __init__(self, cache_device='cpu'):
        self.states = {}
        self._next_pred_id = 0
        # Stored for callers; this class does not read it itself.
        self.cache_device = cache_device

    def new_prediction(self, cache_device='cpu'):
        """Register a new prediction and return its integer id.

        ``cache_device`` replaces the instance's current ``cache_device``.
        """
        self.cache_device = cache_device
        new_id, self._next_pred_id = self._next_pred_id, self._next_pred_id + 1
        self.states[new_id] = dict(
            previous_residual=None,
            accumulated_rel_l1_distance=0,
            previous_modulated_input=None,
            skipped_steps=[],
        )
        return new_id

    def update(self, pred_id, **kwargs):
        """Merge ``kwargs`` into the state of ``pred_id``; unknown ids are ignored."""
        if pred_id in self.states:
            self.states[pred_id].update(kwargs)

    def get(self, pred_id):
        """Return the state dict of ``pred_id``, or a new empty dict if unknown."""
        try:
            return self.states[pred_id]
        except KeyError:
            return {}

    def clear_all(self):
        """Drop every prediction state and reset id allocation to 0."""
        self.states, self._next_pred_id = {}, 0
class MagCacheState:
    """Per-prediction bookkeeping for MagCache.

    States are keyed by integer ids handed out sequentially by
    ``new_prediction``; ``clear_all`` wipes them and restarts at id 0.
    """

    def __init__(self, cache_device='cpu'):
        self.states = {}
        self._next_pred_id = 0
        # Kept for callers; not consulted inside this class.
        self.cache_device = cache_device

    def new_prediction(self, cache_device='cpu'):
        """Allocate a fresh MagCache state and return its id.

        The given ``cache_device`` overwrites the instance's ``cache_device``.
        """
        self.cache_device = cache_device
        new_id = self._next_pred_id
        self._next_pred_id = new_id + 1
        fresh_state = dict(
            residual_cache=None,
            accumulated_ratio=1.0,
            accumulated_steps=0,
            accumulated_err=0,
            skipped_steps=[],
        )
        self.states[new_id] = fresh_state
        return new_id

    def update(self, pred_id, **kwargs):
        """Overwrite the given fields of ``pred_id``'s state; no-op for unknown ids."""
        if pred_id not in self.states:
            return None
        self.states[pred_id].update(kwargs)

    def get(self, pred_id):
        """Return ``pred_id``'s state dict, or a brand-new empty dict if absent."""
        if pred_id in self.states:
            return self.states[pred_id]
        return {}

    def clear_all(self):
        """Forget every state and restart id numbering at 0."""
        self._next_pred_id = 0
        self.states = {}
class EasyCacheState:
    """Per-prediction bookkeeping for EasyCache.

    ``new_prediction`` hands out sequential integer ids, each owning its own
    state dict; ``clear_all`` discards all of them and resets the counter.
    """

    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device  # stored for callers only
        self._next_pred_id = 0
        self.states = {}

    def new_prediction(self, cache_device='cpu'):
        """Create a new EasyCache state, returning the id it is stored under.

        Side effect: ``self.cache_device`` is set to ``cache_device``.
        """
        self.cache_device = cache_device
        new_id, self._next_pred_id = self._next_pred_id, self._next_pred_id + 1
        self.states[new_id] = dict(
            previous_raw_input=None,
            previous_raw_output=None,
            cache=None,
            accumulated_error=0.0,
            skipped_steps=[],
        )
        return new_id

    def update(self, pred_id, **kwargs):
        """Assign each keyword to ``pred_id``'s state; silently skip unknown ids."""
        if pred_id in self.states:
            self.states[pred_id].update(kwargs)
        return None

    def get(self, pred_id):
        """Look up ``pred_id``'s state; an unknown id yields a new empty dict."""
        try:
            return self.states[pred_id]
        except KeyError:
            return {}

    def clear_all(self):
        """Discard all states and restart id allocation at 0."""
        self.states, self._next_pred_id = {}, 0
def relative_l1_distance(last_tensor, current_tensor):
    """Return ``mean(|last - current|) / mean(|last|)`` as a float32 scalar tensor.

    ``last_tensor`` may live on a different device than ``current_tensor``
    (for example a cached copy kept on a separate cache device). It is moved
    to ``current_tensor.device`` once, and both reductions run on that copy,
    so the normalising mean is no longer computed on the original device.

    Args:
        last_tensor: Previous tensor; also the reference for normalisation.
        current_tensor: Current tensor; must be broadcast-compatible with
            ``last_tensor``.

    Returns:
        0-dim float32 tensor on ``current_tensor.device``. No epsilon is added,
        so an all-zero ``last_tensor`` yields ``inf`` (or ``nan`` when the
        difference is also zero).
    """
    last = last_tensor.to(current_tensor.device)
    l1_distance = torch.abs(last - current_tensor).mean()
    norm = torch.abs(last).mean()
    return (l1_distance / norm).to(torch.float32)
def cache_report(transformer, cache_args):
    """Log the steps skipped by the active cache, then reset every cache state.

    For each prediction state of the cache selected by
    ``cache_args["cache_type"]`` (``"TeaCache"``, ``"MagCache"`` or
    ``"EasyCache"``) that has a ``skipped_steps`` entry, one info line is
    logged. Prediction id 0 is labelled "conditional", id 1 "unconditional",
    and any other id ``prediction_<id>``.

    Afterwards the TeaCache, MagCache and EasyCache states on ``transformer``
    are all cleared. An unrecognised cache type, or ``cache_args`` of ``None``,
    logs nothing but still clears the states. It no longer raises
    AttributeError from iterating ``None``.
    """
    cache_type = None if cache_args is None else cache_args["cache_type"]
    state_attr = {
        "TeaCache": "teacache_state",
        "MagCache": "magcache_state",
        "EasyCache": "easycache_state",
    }.get(cache_type)
    state_names = {0: "conditional", 1: "unconditional"}

    if state_attr is not None:
        for pred_id, state in getattr(transformer, state_attr).states.items():
            if 'skipped_steps' not in state:
                continue
            name = state_names.get(pred_id, f"prediction_{pred_id}")
            log.info(f"{cache_type} skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}")

    transformer.teacache_state.clear_all()
    transformer.magcache_state.clear_all()
    transformer.easycache_state.clear_all()