Spaces:
Running on Zero
Running on Zero
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import torch | |
| from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput | |
# `@dataclass` is required here: without it the annotated attributes below are
# plain class attributes, the inherited `__init__` only accepts `sample`, and
# `dataclasses.fields()` does not report the two extra outputs.
@dataclass
class CustomOutput(UNet2DConditionOutput):
    """
    The output of a [`UNet2DConditionModel`]-style forward pass, extended with
    optional intermediate results.

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
        block_outputs (`List[torch.Tensor]`, *optional*):
            Intermediate tensors from the model's blocks; `None` when not
            provided. NOTE(review): exact contents/order are defined by the
            forward() that builds this object — confirm there.
        cross_attention_maps (`List[torch.Tensor]`, *optional*):
            Cross-attention maps gathered during the forward pass; `None` when
            not provided. NOTE(review): shapes depend on the producing
            forward() — confirm there.
    """

    # Redeclared so field order is explicit; default matches the parent class.
    sample: Optional[torch.Tensor] = None
    block_outputs: Optional[List[torch.Tensor]] = None
    cross_attention_maps: Optional[List[torch.Tensor]] = None
# Keyword arguments shared by every UNet2DConditionModel-style network built
# from this module. The layout has four resolution levels
# (`block_out_channels`). Cross-attention blocks are used in the first three
# down blocks, the mid block, and the last three up blocks.
# NOTE(review): these values appear to equal diffusers' UNet2DConditionModel
# constructor defaults — confirm against the installed diffusers version.
# Key order is significant only for iteration; it is kept identical on purpose.
common_args = dict(
    # -- input / output tensors --
    sample_size=None,
    in_channels=4,
    out_channels=4,
    center_input_sample=False,
    flip_sin_to_cos=True,
    freq_shift=0,
    # -- block layout --
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    mid_block_type="UNetMidBlock2DCrossAttn",
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    only_cross_attention=False,
    block_out_channels=(320, 640, 1280, 1280),
    layers_per_block=2,
    downsample_padding=1,
    mid_block_scale_factor=1,
    dropout=0.0,
    act_fn="silu",
    norm_num_groups=32,
    norm_eps=1e-5,
    # -- attention / conditioning --
    cross_attention_dim=1280,
    transformer_layers_per_block=1,
    reverse_transformer_layers_per_block=None,
    encoder_hid_dim=None,
    encoder_hid_dim_type=None,
    attention_head_dim=8,
    num_attention_heads=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    class_embed_type=None,
    addition_embed_type=None,
    addition_time_embed_dim=None,
    num_class_embeds=None,
    upcast_attention=False,
    # -- resnet / time embedding --
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    time_embedding_type="positional",
    time_embedding_dim=None,
    time_embedding_act_fn=None,
    timestep_post_act=None,
    time_cond_proj_dim=None,
    # -- convolution kernels and remaining options --
    conv_in_kernel=3,
    conv_out_kernel=3,
    projection_class_embeddings_input_dim=None,
    attention_type="default",
    class_embeddings_concat=False,
    mid_block_only_cross_attention=None,
    cross_attention_norm=None,
    addition_embed_type_num_heads=64,
)