# Page-extraction residue (not part of the module): "Spaces: Paused Paused"
import os

import numpy as np
import torch

from diffusers_helper.models.mag_cache_ratios import MAG_RATIOS_DB
class MagCache:
    """
    Implements the MagCache algorithm for skipping transformer steps during video generation.

    MagCache: Fast Video Generation with Magnitude-Aware Cache
    Zehong Ma, Longhui Wei, Feng Wang, Shiliang Zhang, Qi Tian
    https://arxiv.org/abs/2506.09045
    https://github.com/Zehong-Ma/MagCache

    PR Demo defaults were threshold=0.1, max_consectutive_skips=3, retention_ratio=0.2
    Changing defaults to threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25 for quality vs speed tradeoff.

    Per-step protocol (see the individual method docstrings):
        1. Call should_skip(hidden_states) once per step.
        2. If it returned True, use estimate_predicted_hidden_states() instead of running the layers.
        3. If it returned False, run the layers and pass their output to update_hidden_states().
    """

    def __init__(self, model_family, height, width, num_steps, is_enabled=True, is_calibrating = False, threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25, max_mag_ratio_deviation=0.06):
        """
        Args:
            model_family: Key used to look up pre-calibrated ratios in MAG_RATIOS_DB.
            height: Output height; averaged with width to pick the closest resolution group.
            width: Output width; averaged with height to pick the closest resolution group.
            num_steps: Number of steps per section.
            is_enabled: Whether the cache may skip steps. Set to False automatically when no
                usable ratios can be determined for the configuration.
            is_calibrating: When True, no step is ever skipped and per-step statistics are
                collected instead; no ratio lookup is performed.
            threshold: Upper bound on the accumulated estimated error of a run of skipped steps.
            max_consectutive_skips: Upper bound on the number of consecutive skipped steps.
                (The misspelled name is kept for backward compatibility.)
            retention_ratio: Fraction of the leading steps that are never skipped.
            max_mag_ratio_deviation: Upper bound on |1 - ratio| of the current step for it to
                be skippable. Defaults to 0.06, the value that was previously hard-coded.
        """
        self.model_family = model_family
        self.height = height
        self.width = width
        self.num_steps = num_steps
        self.is_enabled = is_enabled
        self.is_calibrating = is_calibrating
        self.threshold = threshold
        self.max_consectutive_skips = max_consectutive_skips
        self.retention_ratio = retention_ratio
        self.max_mag_ratio_deviation = max_mag_ratio_deviation

        # total cache statistics for all sections in the entire generation
        self.total_cache_requests = 0
        self.total_cache_hits = 0

        # May flip self.is_enabled to False, so it must run after is_enabled is assigned.
        self.mag_ratios = self._determine_mag_ratios()
        self._init_for_every_section()

    def _init_for_every_section(self):
        """Resets all per-section state (step counter, error accumulation, calibration stats)."""
        self.step_index = 0
        self.steps_skipped_list = []

        # Error accumulation state
        self.accumulated_ratio = 1.0
        self.accumulated_steps = 0
        self.accumulated_err = 0

        # Statistics for calibration
        self.norm_ratio, self.norm_std, self.cos_dis = [], [], []

        self.hidden_states = None
        self.previous_residual = None

        if self.is_calibrating and self.total_cache_requests > 0:
            print('WARNING: Resetting MagCache calibration stats for new section. Typically you only want one section per calibration job. Discarding calibration from previsou section.')

    def should_skip(self, hidden_states):
        """
        Expected to be called once per step during the forward pass, for the number of initialized steps.
        Determines if the current step should be skipped based on estimated accumulated error.
        If the step is skipped, the hidden_states should be replaced with the output of estimate_predicted_hidden_states().

        Never skips while calibrating, while the cache is disabled, or when no ratios are available.

        Args:
            hidden_states: The current hidden states tensor from the transformer model.
        Returns:
            True if the step should be skipped, False otherwise
        """
        if self.step_index == 0 or self.step_index >= self.num_steps:
            self._init_for_every_section()
        self.total_cache_requests += 1
        self.hidden_states = hidden_states.clone()  # Is clone needed?

        if self.is_calibrating:
            # The step counter is advanced by _update_calibration_stats() in this mode.
            print('######################### Calibrating MagCache #########################')
            return False

        should_skip_forward = False
        # Keep the first retention_ratio steps, and always run step 0 so a residual exists.
        first_skippable_step = max(int(self.retention_ratio * self.num_steps), 1)
        # Without ratios (e.g. uncalibrated model family) there is nothing to estimate from.
        cache_usable = self.is_enabled and self.mag_ratios is not None

        if cache_usable and self.step_index >= first_skippable_step:
            cur_mag_ratio = self.mag_ratios[self.step_index]
            self.accumulated_ratio = self.accumulated_ratio * cur_mag_ratio
            cur_skip_err = abs(1 - self.accumulated_ratio)
            self.accumulated_err += cur_skip_err
            self.accumulated_steps += 1
            # RT_BORG: Per my conversation with Zehong Ma, the per-step deviation limit (0.06 by
            # default) is a tunable parameter; it is exposed as max_mag_ratio_deviation.
            within_error_budget = self.accumulated_err <= self.threshold
            within_skip_budget = self.accumulated_steps <= self.max_consectutive_skips
            ratio_close_to_one = abs(1 - cur_mag_ratio) <= self.max_mag_ratio_deviation
            if within_error_budget and within_skip_budget and ratio_close_to_one:
                should_skip_forward = True
            else:
                # This step will be computed for real, so the error estimate starts over.
                self.accumulated_ratio = 1.0
                self.accumulated_steps = 0
                self.accumulated_err = 0

        if should_skip_forward:
            self.total_cache_hits += 1
            self.steps_skipped_list.append(self.step_index)

        # Increment for next step
        self.step_index += 1
        if self.step_index == self.num_steps:
            self.step_index = 0

        return should_skip_forward

    def estimate_predicted_hidden_states(self):
        """
        Should be called if and only if should_skip() returned True for the current step.
        Estimates the hidden states for the current step based on the previous hidden states and residual.

        Returns:
            The estimated hidden states tensor.
        Raises:
            RuntimeError: If no hidden states or no residual have been stored yet.
        """
        if self.hidden_states is None or self.previous_residual is None:
            raise RuntimeError(
                "MagCache has no cached residual to estimate from; call should_skip() and "
                "update_hidden_states() for at least one computed step first."
            )
        return self.hidden_states + self.previous_residual

    def update_hidden_states(self, model_prediction_hidden_states):
        """
        If and only if should_skip() returned False for the current step, the denoising layers should have been run,
        and this function should be called to compute and store the residual for future steps.

        Args:
            model_prediction_hidden_states: The hidden states tensor output from running the denoising layers.
        """
        current_residual = model_prediction_hidden_states - self.hidden_states
        if self.is_calibrating:
            # Must run before previous_residual is overwritten: it compares against the old one.
            self._update_calibration_stats(current_residual)
        self.previous_residual = current_residual

    def _update_calibration_stats(self, current_residual):
        """
        Records calibration statistics comparing current_residual with the previous residual,
        then advances the step counter (should_skip() does not advance it while calibrating).

        From the second step of a section on, appends (rounded to 5 decimals) the mean and
        standard deviation of the per-vector norm ratio and the mean cosine distance, all
        taken along the last dimension. Prints the collected lists after the final step.

        Args:
            current_residual: Residual tensor computed for the current step.
        """
        if self.step_index >= 1:
            magnitude_ratio = current_residual.norm(dim=-1) / self.previous_residual.norm(dim=-1)
            norm_ratio = magnitude_ratio.mean().item()
            norm_std = magnitude_ratio.std().item()
            cos_dis = (1 - torch.nn.functional.cosine_similarity(current_residual, self.previous_residual, dim=-1, eps=1e-8)).mean().item()
            self.norm_ratio.append(round(norm_ratio, 5))
            self.norm_std.append(round(norm_std, 5))
            self.cos_dis.append(round(cos_dis, 5))

        self.step_index += 1
        if self.step_index == self.num_steps:
            print("norm ratio")
            print(self.norm_ratio)
            print("norm std")
            print(self.norm_std)
            print("cos_dis")
            print(self.cos_dis)
            self.step_index = 0

    def _determine_mag_ratios(self):
        """
        Determines the magnitude ratios by finding the closest resolution and step count
        in the pre-calibrated database.

        On an exact step-count match the database entry is returned as stored; otherwise the
        closest entry is resampled to num_steps values. If the model family is unknown or the
        database entry cannot be processed, a warning is printed and is_enabled is set to False.

        Returns:
            The magnitude ratios for the specified configuration, or None when calibrating
            or when no usable ratios were found.
        """
        if self.is_calibrating:
            return None
        try:
            # Find the closest available resolution group for the given model family
            resolution_groups = MAG_RATIOS_DB[self.model_family]
            available_resolutions = list(resolution_groups.keys())
            if not available_resolutions:
                raise ValueError("No resolutions defined for this model family.")
            avg_resolution = (self.height + self.width) / 2.0
            closest_resolution_key = min(available_resolutions, key=lambda r: abs(r - avg_resolution))

            # Find the closest available step count for the given model/resolution
            steps_group = resolution_groups[closest_resolution_key]
            available_steps = list(steps_group.keys())
            if not available_steps:
                raise ValueError(f"No step counts defined for resolution {closest_resolution_key}.")
            closest_steps = min(available_steps, key=lambda x: abs(x - self.num_steps))
            base_ratios = steps_group[closest_steps]

            if closest_steps == self.num_steps:
                print(f"MagCache: Found ratios for {self.model_family}, resolution group {closest_resolution_key} ({self.width}x{self.height}), {self.num_steps} steps.")
                return base_ratios

            print(f"MagCache: Using ratios from {self.model_family}, resolution group {closest_resolution_key} ({self.width}x{self.height}), {closest_steps} steps and interpolating to {self.num_steps} steps.")
            return self._nearest_step_interpolation(base_ratios, self.num_steps)
        except KeyError:
            # This will catch if model_family is not in MAG_RATIOS_DB
            print(f"Warning: MagCache not calibrated for model family '{self.model_family}'. MagCache will not be used.")
            self.is_enabled = False
        except (ValueError, TypeError) as e:
            # This will catch errors if resolution keys or step keys are not numbers, or if groups are empty.
            print(f"Warning: Error processing MagCache DB for model family '{self.model_family}': {e}. MagCache will not be used.")
            self.is_enabled = False
        return None

    # Nearest interpolation function for MagCache mag_ratios
    @staticmethod
    def _nearest_step_interpolation(src_array, target_length):
        """
        Resamples src_array to target_length entries by nearest-index lookup.

        Declared as a static method because it takes no self, yet is invoked through the
        instance; without the decorator that call raises a TypeError.

        Args:
            src_array: Non-empty sequence or numpy array of ratios.
            target_length: Number of entries to produce.
        Returns:
            A numpy array with target_length entries taken from src_array.
        Raises:
            ValueError: If src_array is empty.
        """
        # Accept plain lists too; integer-array indexing below needs an ndarray.
        src_array = np.asarray(src_array)
        src_length = len(src_array)
        if src_length == 0:
            raise ValueError("Cannot interpolate an empty ratio array.")
        if target_length == 1:
            return np.array([src_array[-1]])
        scale = (src_length - 1) / (target_length - 1)
        mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
        return src_array[mapped_indices]

    def append_calibration_to_file(self, output_file):
        """
        Appends one tab-delimited calibration line to output_file:
        model_family, width, height, num_steps, then a "num_steps: np.array([1.0] + [...])," entry
        built from the collected norm ratios.

        Args:
            output_file: Path of the file to append to.
        Returns:
            True if the line was written; False if not calibrating, if no ratios have been
            collected, or if writing failed (the error is printed).
        """
        if not self.is_calibrating or not self.norm_ratio:
            print("Calibration data can only be appended after calibration.")
            return False
        try:
            with open(output_file, "a", encoding="utf-8") as f:
                # Format the data as a string
                calibration_set = f"{self.model_family}\t{self.width}\t{self.height}\t{self.num_steps}"
                entry_string = f"{calibration_set}\t{self.num_steps}: np.array([1.0] + {self.norm_ratio}),"
                # Append the data to the file
                f.write(entry_string + "\n")
            print(f"Calibration data appended to {output_file}")
            return True
        except Exception as e:
            # Best-effort export: report the failure to the caller instead of raising.
            print(f"Error appending calibration data: {e}")
            return False