import logging
import math
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn

from src.Utilities import util
from src.AutoEncoders import ResBlock
from src.NeuralNetwork import transformer
from src.cond import cast
from src.sample import sampling, sampling_util
UNET_MAP_ATTENTIONS = {
    "proj_in.weight", "proj_in.bias", "proj_out.weight", "proj_out.bias",
    "norm.weight", "norm.bias",
}

TRANSFORMER_BLOCKS = {
    "norm1.weight", "norm1.bias", "norm2.weight", "norm2.bias", "norm3.weight", "norm3.bias",
    "attn1.to_q.weight", "attn1.to_k.weight", "attn1.to_v.weight", "attn1.to_out.0.weight", "attn1.to_out.0.bias",
    "attn2.to_q.weight", "attn2.to_k.weight", "attn2.to_v.weight", "attn2.to_out.0.weight", "attn2.to_out.0.bias",
    "ff.net.0.proj.weight", "ff.net.0.proj.bias", "ff.net.2.weight", "ff.net.2.bias",
}
# Internal (OpenAI-style) resnet parameter names -> diffusers parameter names.
UNET_MAP_RESNET = {
    "in_layers.2.weight": "conv1.weight", "in_layers.2.bias": "conv1.bias",
    "emb_layers.1.weight": "time_emb_proj.weight", "emb_layers.1.bias": "time_emb_proj.bias",
    "out_layers.3.weight": "conv2.weight", "out_layers.3.bias": "conv2.bias",
    "skip_connection.weight": "conv_shortcut.weight", "skip_connection.bias": "conv_shortcut.bias",
    "in_layers.0.weight": "norm1.weight", "in_layers.0.bias": "norm1.bias",
    "out_layers.0.weight": "norm2.weight", "out_layers.0.bias": "norm2.bias",
}
# Pairs of (internal key, diffusers key) for top-level parameters. The label_emb
# weights intentionally appear twice: diffusers names them "class_embedding" for
# class-conditional models and "add_embedding" for SDXL-style models.
UNET_MAP_BASIC = {
    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"), ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "class_embedding.linear_2.weight"), ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"), ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "add_embedding.linear_2.weight"), ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"), ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"), ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"), ("out.2.bias", "conv_out.bias"),
    ("time_embed.0.weight", "time_embedding.linear_1.weight"), ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"), ("time_embed.2.bias", "time_embedding.linear_2.bias"),
}
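# Reading the maps: UNET_MAP_RESNET keys are the internal names and values the
# diffusers counterparts, while UNET_MAP_BASIC holds (internal, diffusers)
# pairs. For example:
#
#   UNET_MAP_RESNET["in_layers.2.weight"]  # -> "conv1.weight"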
oai_ops = cast.disable_weight_init
def unet_to_diffusers(unet_config: dict) -> dict:
    """Build a mapping from diffusers parameter names to internal UNet parameter names."""
    if "num_res_blocks" not in unet_config:
        return {}
    num_res_blocks, channel_mult = unet_config["num_res_blocks"], unet_config["channel_mult"]
    transformer_depth, transformer_depth_output = unet_config["transformer_depth"][:], unet_config["transformer_depth_output"][:]
    # Default to 0 (not None) so range() below stays valid for configs without a middle transformer.
    num_blocks, transformers_mid = len(channel_mult), unet_config.get("transformer_depth_middle", 0)
    diffusers_unet_map = {}

    # Down (input) blocks.
    for x in range(num_blocks):
        n = 1 + (num_res_blocks[x] + 1) * x
        for i in range(num_res_blocks[x]):
            for b in UNET_MAP_RESNET:
                diffusers_unet_map[f"down_blocks.{x}.resnets.{i}.{UNET_MAP_RESNET[b]}"] = f"input_blocks.{n}.0.{b}"
            num_transformers = transformer_depth.pop(0)
            if num_transformers > 0:
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map[f"down_blocks.{x}.attentions.{i}.{b}"] = f"input_blocks.{n}.1.{b}"
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map[f"down_blocks.{x}.attentions.{i}.transformer_blocks.{t}.{b}"] = f"input_blocks.{n}.1.transformer_blocks.{t}.{b}"
            n += 1
        for k in ["weight", "bias"]:
            diffusers_unet_map[f"down_blocks.{x}.downsamplers.0.conv.{k}"] = f"input_blocks.{n}.0.op.{k}"

    # Middle block.
    for b in UNET_MAP_ATTENTIONS:
        diffusers_unet_map[f"mid_block.attentions.0.{b}"] = f"middle_block.1.{b}"
    for t in range(transformers_mid):
        for b in TRANSFORMER_BLOCKS:
            diffusers_unet_map[f"mid_block.attentions.0.transformer_blocks.{t}.{b}"] = f"middle_block.1.transformer_blocks.{t}.{b}"
    for i, n in enumerate([0, 2]):
        for b in UNET_MAP_RESNET:
            diffusers_unet_map[f"mid_block.resnets.{i}.{UNET_MAP_RESNET[b]}"] = f"middle_block.{n}.{b}"

    # Up (output) blocks mirror the input blocks, so walk num_res_blocks in reverse.
    num_res_blocks = list(reversed(num_res_blocks))
    for x in range(num_blocks):
        n = (num_res_blocks[x] + 1) * x
        for i in range(num_res_blocks[x] + 1):
            c = 0  # index of the next sub-module inside output_blocks.{n}
            for b in UNET_MAP_RESNET:
                diffusers_unet_map[f"up_blocks.{x}.resnets.{i}.{UNET_MAP_RESNET[b]}"] = f"output_blocks.{n}.0.{b}"
            c += 1
            num_transformers = transformer_depth_output.pop()
            if num_transformers > 0:
                c += 1
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map[f"up_blocks.{x}.attentions.{i}.{b}"] = f"output_blocks.{n}.1.{b}"
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map[f"up_blocks.{x}.attentions.{i}.transformer_blocks.{t}.{b}"] = f"output_blocks.{n}.1.transformer_blocks.{t}.{b}"
            if i == num_res_blocks[x]:
                # The last block at each resolution carries the upsampler; c points
                # past the resnet (and transformer, if any) so the conv key lands on
                # the right sub-module slot.
                for k in ["weight", "bias"]:
                    diffusers_unet_map[f"up_blocks.{x}.upsamplers.0.conv.{k}"] = f"output_blocks.{n}.{c}.conv.{k}"
            n += 1

    for k in UNET_MAP_BASIC:
        diffusers_unet_map[k[1]] = k[0]
    return diffusers_unet_map
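# Minimal usage sketch. The config below is a hypothetical SD1.5-sized UNet
# (illustrative values, not shipped by this module):
#
#   cfg = {"num_res_blocks": [2, 2, 2, 2], "channel_mult": [1, 2, 4, 4],
#          "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
#          "transformer_depth_output": [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
#          "transformer_depth_middle": 1}
#   mapping = unet_to_diffusers(cfg)
#   mapping["down_blocks.0.resnets.0.conv1.weight"]
#   # -> "input_blocks.1.0.in_layers.2.weight"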
def apply_control1(h: torch.Tensor, control: Optional[Any], name: str) -> torch.Tensor:
    # Stub: this build does not wire ControlNet residuals into the UNet;
    # hidden states pass through unchanged.
    return h
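# For reference, ComfyUI-style UNets typically apply control by popping a
# residual for the current block name and adding it to h. A minimal sketch,
# assuming `control` is a dict of tensor lists keyed by "input"/"middle"/
# "output" (an assumption; this build deliberately passes h through):
#
#   ctrl = control.get(name) if control else None
#   if ctrl:
#       residual = ctrl.pop()
#       if residual is not None:
#           h += residual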
class UNetModel1(nn.Module):
    def __init__(self, image_size: int, in_channels: int, model_channels: int, out_channels: int, num_res_blocks: list,
                 dropout: float = 0, channel_mult: tuple = (1, 2, 4, 8), conv_resample: bool = True, dims: int = 2,
                 num_classes: Optional[int] = None, use_checkpoint: bool = False, dtype: torch.dtype = torch.float32,
                 num_heads: int = -1, num_head_channels: int = -1, num_heads_upsample: int = -1,
                 use_scale_shift_norm: bool = False, resblock_updown: bool = False, use_new_attention_order: bool = False,
                 use_spatial_transformer: bool = False, transformer_depth: Any = 1, context_dim: Optional[int] = None,
                 n_embed: Optional[int] = None, legacy: bool = True, disable_self_attentions: Optional[list] = None,
                 num_attention_blocks: Optional[list] = None, disable_middle_self_attn: bool = False,
                 use_linear_in_transformer: bool = False, adm_in_channels: Optional[int] = None,
                 transformer_depth_middle: Optional[int] = None, transformer_depth_output: Optional[list] = None,
                 use_temporal_resblock: bool = False, use_temporal_attention: bool = False,
                 time_context_dim: Optional[int] = None, extra_ff_mix_layer: bool = False, use_spatial_context: bool = False,
                 merge_strategy: Optional[Any] = None, merge_factor: float = 0.0, video_kernel_size: Optional[int] = None,
                 disable_temporal_crossattention: bool = False, max_ddpm_temb_period: int = 10000,
                 device: Optional[torch.device] = None, operations: Any = oai_ops):
        super().__init__()

        if context_dim is not None:
            self.context_dim = context_dim
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads
        if num_head_channels == -1:
            assert num_heads != -1, "either num_heads or num_head_channels must be set"

        self.in_channels, self.model_channels, self.out_channels = in_channels, model_channels, out_channels
        self.num_res_blocks, self.dropout, self.channel_mult = num_res_blocks, dropout, channel_mult
        self.conv_resample, self.num_classes, self.use_checkpoint = conv_resample, num_classes, use_checkpoint
        self.dtype, self.num_heads, self.num_head_channels = dtype, num_heads, num_head_channels
        self.num_heads_upsample, self.use_temporal_resblocks = num_heads_upsample, use_temporal_resblock
        self.predict_codebook_ids, self.default_num_video_frames = n_embed is not None, None

        # Normalize the transformer-depth arguments into per-res-block lists and
        # copy them, since the construction loops below consume them with pop().
        # The int/None fallbacks are convenience assumptions; detected configs
        # always pass explicit lists.
        if isinstance(transformer_depth, int):
            transformer_depth = [transformer_depth] * sum(num_res_blocks)
        transformer_depth = list(transformer_depth)
        if transformer_depth_output is None:
            transformer_depth_output = [0] * (sum(num_res_blocks) + len(num_res_blocks))
        transformer_depth_output = list(transformer_depth_output)
        if transformer_depth_middle is None:
            transformer_depth_middle = transformer_depth[-1] if transformer_depth else 0

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device))

        if adm_in_channels is not None:
            # "sequential" class embedding: an MLP over the ADM conditioning vector.
            self.label_emb = nn.Sequential(
                nn.Sequential(
                    operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                    nn.SiLU(),
                    operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device)))

        self.input_blocks = nn.ModuleList([sampling.TimestepEmbedSequential1(
            operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device))])
        self._feature_size, input_block_chans, ch, ds = model_channels, [model_channels], model_channels, 1
        self.double_blocks = nn.ModuleList()

        def make_attn(ch, depth, context_dim, disable_self_attn=False):
            dim_head = ch // num_heads if num_head_channels == -1 else num_head_channels
            heads = num_heads if num_head_channels == -1 else ch // num_head_channels
            return transformer.SpatialTransformer(
                ch, heads, dim_head, depth=depth, context_dim=context_dim,
                disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations)

        def make_res(ch, out_ch=None, down=False, up=False):
            return ResBlock.ResBlock1(
                channels=ch, emb_channels=time_embed_dim, dropout=dropout,
                out_channels=out_ch, use_checkpoint=use_checkpoint, dims=dims,
                use_scale_shift_norm=use_scale_shift_norm, down=down, up=up,
                dtype=self.dtype, device=device, operations=operations)

        # Encoder (down) path.
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [make_res(ch, mult * model_channels)]
                ch = mult * model_channels
                num_trans = transformer_depth.pop(0)
                if num_trans > 0 and (not util.exists(num_attention_blocks) or nr < num_attention_blocks[level]):
                    layers.append(make_attn(ch, num_trans, context_dim))
                self.input_blocks.append(sampling.TimestepEmbedSequential1(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(sampling.TimestepEmbedSequential1(
                    make_res(ch, out_ch, down=True) if resblock_updown else
                    ResBlock.Downsample1(ch, conv_resample, dims=dims, out_channels=out_ch,
                                         dtype=self.dtype, device=device, operations=operations)))
                ch, input_block_chans, ds = out_ch, input_block_chans + [out_ch], ds * 2
                self._feature_size += ch

        # Middle block: resnet, optionally followed by a transformer and a second resnet.
        dim_head = ch // num_heads if num_head_channels == -1 else num_head_channels
        mid_block = [make_res(ch)]
        self.middle_block = None
        if transformer_depth_middle >= -1:
            if transformer_depth_middle >= 0:
                mid_block += [make_attn(ch, transformer_depth_middle, context_dim, disable_middle_self_attn), make_res(ch)]
            self.middle_block = sampling.TimestepEmbedSequential1(*mid_block)
        self._feature_size += ch

        # Decoder (up) path; each block consumes a skip connection from input_block_chans.
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [make_res(ch + ich, model_channels * mult)]
                ch = model_channels * mult
                num_trans = transformer_depth_output.pop()
                if num_trans > 0 and (not util.exists(num_attention_blocks) or i < num_attention_blocks[level]):
                    layers.append(make_attn(ch, num_trans, context_dim))
                if level and i == self.num_res_blocks[level]:
                    layers.append(make_res(ch, ch, up=True) if resblock_updown else
                                  ResBlock.Upsample1(ch, conv_resample, dims=dims, out_channels=ch,
                                                     dtype=self.dtype, device=device, operations=operations))
                    ds //= 2
                self.output_blocks.append(sampling.TimestepEmbedSequential1(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
            nn.SiLU(),
            util.zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1,
                                                dtype=self.dtype, device=device)))
    def forward(self, x: torch.Tensor, timesteps: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None,
                y: Optional[torch.Tensor] = None, control: Optional[Any] = None,
                transformer_options: Optional[Dict[str, Any]] = None, **kwargs: Any) -> torch.Tensor:
        # Use None instead of a mutable {} default so callers never share state.
        if transformer_options is None:
            transformer_options = {}
        transformer_options["original_shape"], transformer_options["transformer_index"] = list(x.shape), 0
        num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
        image_only_indicator, time_context = kwargs.get("image_only_indicator"), kwargs.get("time_context")

        if self.num_classes is None:
            y = None
        elif y is None:
            raise ValueError("y is required for models with num_classes")

        emb = self.time_embed(sampling_util.timestep_embedding(timesteps, self.model_channels).to(device=x.device, dtype=x.dtype))
        if y is not None:
            emb = emb + self.label_emb(y.to(device=x.device, dtype=x.dtype))

        hs, h = [], x
        for id, module in enumerate(self.input_blocks):
            transformer_options["block"] = ("input", id)
            h = ResBlock.forward_timestep_embed1(module, h, emb, context, transformer_options,
                                                 time_context=time_context, num_video_frames=num_video_frames,
                                                 image_only_indicator=image_only_indicator)
            h = apply_control1(h, control, "input")
            hs.append(h)

        transformer_options["block"] = ("middle", 0)
        if self.middle_block is not None:
            h = ResBlock.forward_timestep_embed1(self.middle_block, h, emb, context, transformer_options,
                                                 time_context=time_context, num_video_frames=num_video_frames,
                                                 image_only_indicator=image_only_indicator)
        h = apply_control1(h, control, "middle")

        for id, module in enumerate(self.output_blocks):
            transformer_options["block"] = ("output", id)
            hsp = apply_control1(hs.pop(), control, "output")
            h = torch.cat([h, hsp], dim=1)
            del hsp
            # Pass the next skip connection's shape so upsamplers can match it exactly.
            h = ResBlock.forward_timestep_embed1(module, h, emb, context, transformer_options,
                                                 hs[-1].shape if hs else None, time_context=time_context,
                                                 num_video_frames=num_video_frames,
                                                 image_only_indicator=image_only_indicator)
        return self.out(h.type(x.dtype))
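# Minimal instantiation sketch with SD1.5-like hyperparameters. All values are
# illustrative assumptions (mirroring common SD1.5 configs), not constants
# exported by this module:
#
#   model = UNetModel1(
#       image_size=32, in_channels=4, model_channels=320, out_channels=4,
#       num_res_blocks=[2, 2, 2, 2], channel_mult=(1, 2, 4, 4),
#       transformer_depth=[1, 1, 1, 1, 1, 1, 0, 0],
#       transformer_depth_output=[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
#       transformer_depth_middle=1, context_dim=768, num_heads=8,
#       use_spatial_transformer=True)
#   x, t = torch.randn(1, 4, 64, 64), torch.tensor([999])
#   ctx = torch.randn(1, 77, 768)  # CLIP text conditioning
#   eps = model(x, timesteps=t, context=ctx)  # same shape as x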
def detect_unet_config(state_dict: Dict[str, torch.Tensor], key_prefix: str) -> Optional[Dict[str, Any]]:
    state_dict_keys = list(state_dict.keys())

    # MMDiT (SD3-style) model
    if f"{key_prefix}joint_blocks.0.context_block.attn.qkv.weight" in state_dict_keys:
        cfg = {"in_channels": state_dict[f"{key_prefix}x_embedder.proj.weight"].shape[1],
               "patch_size": state_dict[f"{key_prefix}x_embedder.proj.weight"].shape[2],
               "depth": state_dict[f"{key_prefix}x_embedder.proj.weight"].shape[0] // 64,
               "input_size": None}
        if f"{key_prefix}final_layer.linear.weight" in state_dict:
            cfg["out_channels"] = state_dict[f"{key_prefix}final_layer.linear.weight"].shape[0] // (cfg["patch_size"] ** 2)
        if f"{key_prefix}y_embedder.mlp.0.weight" in state_dict_keys:
            cfg["adm_in_channels"] = state_dict[f"{key_prefix}y_embedder.mlp.0.weight"].shape[1]
        if f"{key_prefix}context_embedder.weight" in state_dict_keys:
            w = state_dict[f"{key_prefix}context_embedder.weight"]
            cfg["context_embedder_config"] = {"target": "torch.nn.Linear",
                                              "params": {"in_features": w.shape[1], "out_features": w.shape[0]}}
        if f"{key_prefix}pos_embed" in state_dict_keys:
            cfg["num_patches"] = state_dict[f"{key_prefix}pos_embed"].shape[1]
            cfg["pos_embed_max_size"] = round(math.sqrt(cfg["num_patches"]))
        if f"{key_prefix}joint_blocks.0.context_block.attn.ln_q.weight" in state_dict_keys:
            cfg["qk_norm"] = "rms"
        cfg["pos_embed_scaling_factor"] = None
        if f"{key_prefix}context_processor.layers.0.attn.qkv.weight" in state_dict_keys:
            cfg["context_processor_layers"] = transformer.count_blocks(state_dict_keys, f"{key_prefix}context_processor.layers." + "{}.")
        return cfg

    # Stable Cascade
    if f"{key_prefix}clf.1.weight" in state_dict_keys:
        cfg = {}
        if f"{key_prefix}clip_txt_mapper.weight" in state_dict_keys:
            cfg["stable_cascade_stage"] = "c"
            w = state_dict[f"{key_prefix}clip_txt_mapper.weight"]
            if w.shape[0] == 1536:
                cfg.update({"c_cond": 1536, "c_hidden": [1536, 1536], "nhead": [24, 24], "blocks": [[4, 12], [12, 4]]})
            elif w.shape[0] == 2048:
                cfg["c_cond"] = 2048
        elif f"{key_prefix}clip_mapper.weight" in state_dict_keys:
            cfg["stable_cascade_stage"] = "b"
            w = state_dict[f"{key_prefix}down_blocks.1.0.channelwise.0.weight"]
            if w.shape[-1] == 640:
                cfg.update({"c_hidden": [320, 640, 1280, 1280], "nhead": [-1, -1, 20, 20],
                            "blocks": [[2, 6, 28, 6], [6, 28, 6, 2]], "block_repeat": [[1, 1, 1, 1], [3, 3, 2, 2]]})
            elif w.shape[-1] == 576:
                cfg.update({"c_hidden": [320, 576, 1152, 1152], "nhead": [-1, 9, 18, 18],
                            "blocks": [[2, 4, 14, 4], [4, 14, 4, 2]], "block_repeat": [[1, 1, 1, 1], [2, 2, 2, 2]]})
        return cfg

    # Stable Audio DiT
    if f"{key_prefix}transformer.rotary_pos_emb.inv_freq" in state_dict_keys:
        return {"audio_model": "dit1.0"}

    # AuraFlow DiT
    if f"{key_prefix}double_layers.0.attn.w1q.weight" in state_dict_keys:
        double_layers = transformer.count_blocks(state_dict_keys, f"{key_prefix}double_layers." + "{}.")
        single_layers = transformer.count_blocks(state_dict_keys, f"{key_prefix}single_layers." + "{}.")
        return {"max_seq": state_dict[f"{key_prefix}positional_encoding"].shape[1],
                "cond_seq_dim": state_dict[f"{key_prefix}cond_seq_linear.weight"].shape[1],
                "n_double_layers": double_layers, "n_layers": double_layers + single_layers}

    # Hunyuan DiT
    if f"{key_prefix}mlp_t5.0.weight" in state_dict_keys:
        cfg = {"image_model": "hydit",
               "depth": transformer.count_blocks(state_dict_keys, f"{key_prefix}blocks." + "{}."),
               "hidden_size": state_dict[f"{key_prefix}x_embedder.proj.weight"].shape[0]}
        if cfg["hidden_size"] == 1408 and cfg["depth"] == 40:
            cfg["mlp_ratio"] = 4.3637
        if state_dict[f"{key_prefix}extra_embedder.0.weight"].shape[1] == 3968:
            cfg.update({"size_cond": True, "use_style_cond": True, "image_model": "hydit1"})
        return cfg

    # Flux
    if f"{key_prefix}double_blocks.0.img_attn.norm.key_norm.scale" in state_dict_keys:
        return {"image_model": "flux", "in_channels": 16, "vec_in_dim": 768, "context_in_dim": 4096,
                "hidden_size": 3072, "mlp_ratio": 4.0, "num_heads": 24,
                "depth": transformer.count_blocks(state_dict_keys, f"{key_prefix}double_blocks." + "{}."),
                "depth_single_blocks": transformer.count_blocks(state_dict_keys, f"{key_prefix}single_blocks." + "{}."),
                "axes_dim": [16, 56, 56], "theta": 10000, "qkv_bias": True,
                "guidance_embed": f"{key_prefix}guidance_in.in_layer.weight" in state_dict_keys}

    if f"{key_prefix}input_blocks.0.0.weight" not in state_dict_keys:
        return None

    # Standard Stable Diffusion UNet
    cfg = {"use_checkpoint": False, "image_size": 32, "use_spatial_transformer": True, "legacy": False}
    if f"{key_prefix}label_emb.0.0.weight" in state_dict_keys:
        cfg["num_classes"], cfg["adm_in_channels"] = "sequential", state_dict[f"{key_prefix}label_emb.0.0.weight"].shape[1]
    else:
        cfg["adm_in_channels"] = None

    model_channels = state_dict[f"{key_prefix}input_blocks.0.0.weight"].shape[0]
    in_channels = state_dict[f"{key_prefix}input_blocks.0.0.weight"].shape[1]
    out_key = f"{key_prefix}out.2.weight"
    # Fall back to 4 latent channels when the final conv is absent from the checkpoint.
    out_channels = state_dict[out_key].shape[0] if out_key in state_dict else 4

    num_res_blocks, channel_mult, transformer_depth, transformer_depth_output = [], [], [], []
    context_dim, use_linear_in_transformer = None, False
    current_res, last_res_blocks, last_channel_mult = 1, 0, 0

    input_block_count = transformer.count_blocks(state_dict_keys, f"{key_prefix}input_blocks" + ".{}.")
    for count in range(input_block_count):
        prefix = f"{key_prefix}input_blocks.{count}."
        prefix_output = f"{key_prefix}output_blocks.{input_block_count - count - 1}."
        block_keys = sorted(k for k in state_dict_keys if k.startswith(prefix))
        if not block_keys:
            break
        block_keys_output = sorted(k for k in state_dict_keys if k.startswith(prefix_output))
        if f"{prefix}0.op.weight" in block_keys:
            # Downsample block: close out the current resolution level.
            num_res_blocks.append(last_res_blocks)
            channel_mult.append(last_channel_mult)
            current_res *= 2
            last_res_blocks, last_channel_mult = 0, 0
            out = transformer.calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
            transformer_depth_output.append(out[0] if out else 0)
        else:
            if f"{prefix}0.in_layers.0.weight" in block_keys:
                last_res_blocks += 1
                last_channel_mult = state_dict[f"{prefix}0.out_layers.3.weight"].shape[0] // model_channels
                out = transformer.calculate_transformer_depth(prefix, state_dict_keys, state_dict)
                if out:
                    transformer_depth.append(out[0])
                    if context_dim is None:
                        context_dim, use_linear_in_transformer = out[1], out[2]
                else:
                    transformer_depth.append(0)
            if f"{prefix_output}0.in_layers.0.weight" in block_keys_output:
                out = transformer.calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
                transformer_depth_output.append(out[0] if out else 0)
    num_res_blocks.append(last_res_blocks)
    channel_mult.append(last_channel_mult)

    if f"{key_prefix}middle_block.1.proj_in.weight" in state_dict_keys:
        transformer_depth_middle = transformer.count_blocks(state_dict_keys, f"{key_prefix}middle_block.1.transformer_blocks." + "{}")
    elif f"{key_prefix}middle_block.0.in_layers.0.weight" in state_dict_keys:
        transformer_depth_middle = -1  # middle block without a transformer
    else:
        transformer_depth_middle = -2  # no middle block

    cfg.update({"in_channels": in_channels, "out_channels": out_channels, "model_channels": model_channels,
                "num_res_blocks": num_res_blocks, "transformer_depth": transformer_depth,
                "transformer_depth_output": transformer_depth_output, "channel_mult": channel_mult,
                "transformer_depth_middle": transformer_depth_middle, "use_linear_in_transformer": use_linear_in_transformer,
                "context_dim": context_dim, "use_temporal_resblock": False, "use_temporal_attention": False})
    return cfg
def model_config_from_unet_config(unet_config: Dict[str, Any], state_dict: Optional[Dict[str, torch.Tensor]] = None) -> Any:
    from src.SD15 import SD15  # local import, commonly done to avoid a circular dependency

    for model_config in SD15.models:
        if model_config.matches(unet_config, state_dict):
            return model_config(unet_config)
    logging.error(f"no match {unet_config}")
    return None
def model_config_from_unet(state_dict: Dict[str, torch.Tensor], unet_key_prefix: str, use_base_if_no_match: bool = False) -> Any:
    # NOTE: use_base_if_no_match is accepted for API compatibility but unused
    # here; an unrecognized config simply yields None.
    unet_config = detect_unet_config(state_dict, unet_key_prefix)
    return model_config_from_unet_config(unet_config, state_dict) if unet_config else None
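# Typical usage sketch. UNet weights in full checkpoints conventionally live
# under the "model.diffusion_model." prefix; the filename and the "state_dict"
# wrapper key below are common-layout assumptions, not requirements of this
# module:
#
#   sd = torch.load("model.ckpt", map_location="cpu")["state_dict"]
#   model_config = model_config_from_unet(sd, "model.diffusion_model.")
#   if model_config is None:
#       raise RuntimeError("unsupported or unrecognized UNet")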
def unet_dtype1(device: Optional[torch.device] = None, model_params: int = 0,
                supported_dtypes: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32]) -> torch.dtype:
    # Hardcoded: this build always runs the UNet in fp16; the device, parameter
    # count, and supported_dtypes arguments are currently ignored.
    return torch.float16
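# A fuller heuristic would consult the device and model size. A sketch using
# only public torch APIs (the policy itself is an assumption, not this build's
# behavior):
#
#   if device is not None and device.type == "cuda" and torch.cuda.is_bf16_supported():
#       return torch.bfloat16 if torch.bfloat16 in supported_dtypes else torch.float16
#   return torch.float16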