Spaces:
Running on Zero
Running on Zero
| # Copyright 2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import inspect | |
| from dataclasses import dataclass | |
| from typing import Any, Callable, Dict, Type | |
@dataclass
class AttentionProcessorMetadata:
    """Metadata attached to an attention processor class in `AttentionProcessorRegistry`.

    Declared as a dataclass so it can be constructed with keyword arguments, which is how
    `_register_attention_processors_metadata` creates every instance.
    """

    # Callable stored per processor class. The registered implementations in this file take
    # the processor call's `(self, *args, **kwargs)` and hand back the input hidden states.
    skip_processor_output_fn: Callable[[Any], Any]
@dataclass
class TransformerBlockMetadata:
    """Metadata attached to a transformer block class in `TransformerBlockRegistry`.

    Declared as a dataclass so it can be constructed with keyword arguments, which is how
    `_register_transformer_blocks_metadata` creates every instance.
    """

    # NOTE(review): presumably the positions of `hidden_states` / `encoder_hidden_states`
    # in the block's forward output (None when that output is absent) - confirm with callers.
    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None

    # Block class this metadata describes; assigned by `TransformerBlockRegistry.register`.
    _cls: Type = None
    # Lazily built map from `forward` parameter name to its positional index (excluding `self`).
    _cached_parameter_indices: Dict[str, int] = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        """Fetch the value passed for the `forward` parameter named `identifier`.

        Keyword arguments take precedence. Otherwise the positional index of `identifier` is
        looked up in the signature of `self._cls.forward` (computed once and cached) and the
        value is read from `args`.

        Args:
            identifier: Name of the `forward` parameter to retrieve.
            args: Positional arguments of the call (without `self`).
            kwargs: Keyword arguments of the call; `None` is treated as empty.

        Returns:
            The argument value bound to `identifier`.

        Raises:
            ValueError: If the value must be resolved positionally and no model class is set,
                `identifier` is not a parameter of `forward`, or `args` is too short to
                contain it.
        """
        kwargs = kwargs or {}
        if identifier in kwargs:
            return kwargs[identifier]

        if self._cached_parameter_indices is None:
            if self._cls is None:
                raise ValueError("Model class is not set for metadata.")
            parameters = list(inspect.signature(self._cls.forward).parameters.keys())
            parameters = parameters[1:]  # skip `self`
            self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}

        # Validate on every call, not only the one that builds the cache, so that cached
        # lookups report the same ValueError instead of a bare KeyError/IndexError.
        if identifier not in self._cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
        index = self._cached_parameter_indices[identifier]
        if index >= len(args):
            raise ValueError(f"Expected at least {index + 1} positional arguments but got {len(args)}.")
        return args[index]
class AttentionProcessorRegistry:
    """Class-level registry mapping attention processor classes to their metadata.

    All methods are classmethods: the registry is never instantiated and is used directly
    through the class (e.g. `AttentionProcessorRegistry.register(...)`).
    """

    _registry = {}
    # TODO(aryan): this is only required for the time being because we need to do the registrations
    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
    # import errors because of the models imported in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
        """Associate `metadata` with `model_class`, replacing any previous entry."""
        cls._register()
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
        """Return the metadata registered for `model_class`.

        Raises:
            ValueError: If `model_class` has not been registered.
        """
        cls._register()
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

    @classmethod
    def _register(cls):
        """Populate the registry with the built-in entries on first use."""
        if cls._is_registered:
            return
        # The flag is set before populating because the registration function calls
        # `register`, which calls `_register` again; this stops the recursion.
        cls._is_registered = True
        _register_attention_processors_metadata()
class TransformerBlockRegistry:
    """Class-level registry mapping transformer block classes to their metadata.

    All methods are classmethods: the registry is never instantiated and is used directly
    through the class (e.g. `TransformerBlockRegistry.register(...)`).
    """

    _registry = {}
    # TODO(aryan): this is only required for the time being because we need to do the registrations
    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
    # import errors because of the models imported in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
        """Associate `metadata` with `model_class`, replacing any previous entry.

        Also stores `model_class` on the metadata (`metadata._cls`) so the metadata can later
        inspect the signature of the block's `forward`.
        """
        cls._register()
        metadata._cls = model_class
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        """Return the metadata registered for `model_class`.

        Raises:
            ValueError: If `model_class` has not been registered.
        """
        cls._register()
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

    @classmethod
    def _register(cls):
        """Populate the registry with the built-in entries on first use."""
        if cls._is_registered:
            return
        # The flag is set before populating because the registration function calls
        # `register`, which calls `_register` again; this stops the recursion.
        cls._is_registered = True
        _register_transformer_blocks_metadata()
def _register_attention_processors_metadata():
    """Register the built-in attention processor classes with `AttentionProcessorRegistry`.

    The model imports are done inside the function rather than at module scope to avoid
    circular imports between this module and the model modules.
    """
    from ..models.attention_processor import AttnProcessor2_0
    from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
    from ..models.transformers.transformer_flux import FluxAttnProcessor
    from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
    from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
    from ..models.transformers.transformer_wan import WanAttnProcessor2_0
    from ..models.transformers.transformer_z_image import ZSingleStreamAttnProcessor

    # (processor class, function producing the output to use when the processor is skipped)
    processor_skip_fns = (
        (AttnProcessor2_0, _skip_proc_output_fn_Attention_AttnProcessor2_0),
        (CogView4AttnProcessor, _skip_proc_output_fn_Attention_CogView4AttnProcessor),
        (WanAttnProcessor2_0, _skip_proc_output_fn_Attention_WanAttnProcessor2_0),
        (FluxAttnProcessor, _skip_proc_output_fn_Attention_FluxAttnProcessor),
        (QwenDoubleStreamAttnProcessor2_0, _skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0),
        (HunyuanImageAttnProcessor, _skip_proc_output_fn_Attention_HunyuanImageAttnProcessor),
        (ZSingleStreamAttnProcessor, _skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor),
    )

    for processor_cls, skip_fn in processor_skip_fns:
        AttentionProcessorRegistry.register(
            model_class=processor_cls,
            metadata=AttentionProcessorMetadata(skip_processor_output_fn=skip_fn),
        )
def _register_transformer_blocks_metadata():
    """Register the built-in transformer block classes with `TransformerBlockRegistry`.

    The model imports are done inside the function rather than at module scope to avoid
    circular imports between this module and the model modules.
    """
    from ..models.attention import BasicTransformerBlock
    from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
    from ..models.transformers.transformer_bria import BriaTransformerBlock
    from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
    from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
    from ..models.transformers.transformer_hunyuan_video import (
        HunyuanVideoSingleTransformerBlock,
        HunyuanVideoTokenReplaceSingleTransformerBlock,
        HunyuanVideoTokenReplaceTransformerBlock,
        HunyuanVideoTransformerBlock,
    )
    from ..models.transformers.transformer_hunyuanimage import (
        HunyuanImageSingleTransformerBlock,
        HunyuanImageTransformerBlock,
    )
    from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
    from ..models.transformers.transformer_mochi import MochiTransformerBlock
    from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
    from ..models.transformers.transformer_wan import WanTransformerBlock
    from ..models.transformers.transformer_z_image import ZImageTransformerBlock

    # (block class, return_hidden_states_index, return_encoder_hidden_states_index)
    block_return_indices = (
        # BasicTransformerBlock
        (BasicTransformerBlock, 0, None),
        (BriaTransformerBlock, 0, None),
        # CogVideoX
        (CogVideoXBlock, 0, 1),
        # CogView4
        (CogView4TransformerBlock, 0, 1),
        # Flux
        (FluxTransformerBlock, 1, 0),
        (FluxSingleTransformerBlock, 1, 0),
        # HunyuanVideo
        (HunyuanVideoTransformerBlock, 0, 1),
        (HunyuanVideoSingleTransformerBlock, 0, 1),
        (HunyuanVideoTokenReplaceTransformerBlock, 0, 1),
        (HunyuanVideoTokenReplaceSingleTransformerBlock, 0, 1),
        # LTXVideo
        (LTXVideoTransformerBlock, 0, None),
        # Mochi
        (MochiTransformerBlock, 0, 1),
        # Wan
        (WanTransformerBlock, 0, None),
        # QwenImage
        (QwenImageTransformerBlock, 1, 0),
        # HunyuanImage2.1
        (HunyuanImageTransformerBlock, 0, 1),
        (HunyuanImageSingleTransformerBlock, 0, 1),
        # ZImage
        (ZImageTransformerBlock, 0, None),
    )

    for block_cls, hidden_idx, encoder_hidden_idx in block_return_indices:
        # A fresh metadata object per class: `register` stores the class on the metadata.
        TransformerBlockRegistry.register(
            model_class=block_cls,
            metadata=TransformerBlockMetadata(
                return_hidden_states_index=hidden_idx,
                return_encoder_hidden_states_index=encoder_hidden_idx,
            ),
        )
| # fmt: off | |
| def _skip_attention___ret___hidden_states(self, *args, **kwargs): | |
| hidden_states = kwargs.get("hidden_states", None) | |
| if hidden_states is None and len(args) > 0: | |
| hidden_states = args[0] | |
| return hidden_states | |
| def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): | |
| hidden_states = kwargs.get("hidden_states", None) | |
| encoder_hidden_states = kwargs.get("encoder_hidden_states", None) | |
| if hidden_states is None and len(args) > 0: | |
| hidden_states = args[0] | |
| if encoder_hidden_states is None and len(args) > 1: | |
| encoder_hidden_states = args[1] | |
| return hidden_states, encoder_hidden_states | |
# Per-processor aliases of the skip-output helpers; these are the names referenced when the
# attention processor metadata is registered.

# Processors whose skip output is `(hidden_states, encoder_hidden_states)`.
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states

# Processors whose skip output is `hidden_states` only.
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor = _skip_attention___ret___hidden_states
# NOTE(review): the original author was not sure what the right output is for this one yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
| # fmt: on | |