Spaces:
Build error
Build error
| from diffusers import ( | |
| CogVideoXTransformer3DModel, | |
| CogView4Transformer2DModel, | |
| FluxTransformer2DModel, | |
| WanTransformer3DModel, | |
| ) | |
| from finetrainers._metadata import CPInput, CPOutput, ParamId, TransformerMetadata, TransformerRegistry | |
| from finetrainers.logging import get_logger | |
# Module-level logger from finetrainers.logging; register_transformer_metadata
# uses it to emit its final debug message.
logger = get_logger()
def register_transformer_metadata():
    """Register a ``TransformerMetadata`` entry for each supported transformer.

    Calls ``TransformerRegistry.register`` once per model class, in this
    order: CogVideoX, CogView4, Flux, Wan2.1. Each entry carries a
    ``cp_plan`` dict whose keys are module-path strings (``""``,
    ``"transformer_blocks.0"``, ``"rope"``, ``"proj_out"``, ...) and whose
    values are either a mapping of ``ParamId`` to ``CPInput`` (or a list of
    ``CPInput``), or a list of ``CPOutput``. After all four registrations a
    debug message is logged.

    NOTE(review): "cp" presumably stands for context parallelism and the
    numeric arguments of ``CPInput``/``CPOutput`` presumably describe the
    split dimension and expected rank -- confirm against
    ``finetrainers._metadata`` before relying on that reading.
    """
    # CogVideoX
    cogvideox_plan = {
        "": {
            ParamId("image_rotary_emb", 5): [CPInput(0, 2), CPInput(0, 2)],
        },
        "transformer_blocks.0": {
            ParamId("hidden_states", 0): CPInput(1, 3),
            ParamId("encoder_hidden_states", 1): CPInput(1, 3),
        },
        "proj_out": [CPOutput(1, 3)],
    }
    TransformerRegistry.register(
        model_class=CogVideoXTransformer3DModel,
        metadata=TransformerMetadata(cp_plan=cogvideox_plan),
    )

    # CogView4
    cogview4_plan = {
        "patch_embed": {
            ParamId(index=0): CPInput(1, 3, split_output=True),
            ParamId(index=1): CPInput(1, 3, split_output=True),
        },
        "rope": {
            ParamId(index=0): CPInput(0, 2, split_output=True),
            ParamId(index=1): CPInput(0, 2, split_output=True),
        },
        "proj_out": [CPOutput(1, 3)],
    }
    TransformerRegistry.register(
        model_class=CogView4Transformer2DModel,
        metadata=TransformerMetadata(cp_plan=cogview4_plan),
    )

    # Flux
    flux_plan = {
        "": {
            ParamId("hidden_states", 0): CPInput(1, 3),
            ParamId("encoder_hidden_states", 1): CPInput(1, 3),
            ParamId("img_ids", 4): CPInput(0, 2),
            ParamId("txt_ids", 5): CPInput(0, 2),
        },
        "proj_out": [CPOutput(1, 3)],
    }
    TransformerRegistry.register(
        model_class=FluxTransformer2DModel,
        metadata=TransformerMetadata(cp_plan=flux_plan),
    )

    # Wan2.1
    # Key order is kept as in the original: the "blocks.*" pattern entry
    # comes before the specific "blocks.0" entry.
    wan_plan = {
        "rope": {
            ParamId(index=0): CPInput(2, 4, split_output=True),
        },
        "blocks.*": {
            ParamId("encoder_hidden_states", 1): CPInput(1, 3),
        },
        "blocks.0": {
            ParamId("hidden_states", 0): CPInput(1, 3),
        },
        "proj_out": [CPOutput(1, 3)],
    }
    TransformerRegistry.register(
        model_class=WanTransformer3DModel,
        metadata=TransformerMetadata(cp_plan=wan_plan),
    )

    logger.debug("Metadata for transformer registered")