Spaces:
Runtime error
Runtime error
# Parameter names carried by the spatial-transformer wrapper module itself
# (projections in/out and its norm).  These suffixes are spelled the same in
# the ldm and diffusers layouts, so remapping only swaps the module prefix.
UNET_MAP_ATTENTIONS = {
    "{}.{}".format(module, param)
    for module in ("proj_in", "proj_out", "norm")
    for param in ("weight", "bias")
}
# Parameter names inside a single transformer block (norms, self- and
# cross-attention projections, feed-forward).  Identical suffixes in both
# checkpoint layouts; only the prefix in front of them is rewritten.
TRANSFORMER_BLOCKS = (
    {"{}.{}".format(norm, p)
     for norm in ("norm1", "norm2", "norm3")
     for p in ("weight", "bias")}
    | {"attn{}.to_{}.weight".format(a, proj)
       for a in (1, 2)
       for proj in ("q", "k", "v")}
    | {"attn{}.to_out.0.{}".format(a, p)
       for a in (1, 2)
       for p in ("weight", "bias")}
    | {"ff.net.0.proj.weight", "ff.net.0.proj.bias",
       "ff.net.2.weight", "ff.net.2.bias"}
)
# ldm-resnet parameter name -> diffusers-resnet parameter name.  Each
# submodule pair contributes both its .weight and .bias entries.
UNET_MAP_RESNET = {
    "{}.{}".format(ldm, p): "{}.{}".format(hf, p)
    for ldm, hf in (
        ("in_layers.2", "conv1"),
        ("emb_layers.1", "time_emb_proj"),
        ("out_layers.3", "conv2"),
        ("skip_connection", "conv_shortcut"),
        ("in_layers.0", "norm1"),
        ("out_layers.0", "norm2"),
    )
    for p in ("weight", "bias")
}
# (ldm_name, diffusers_name) pairs for parameters outside the repeating block
# structure (embeddings, conv in/out, final norm).  Note that both
# class_embedding.* and add_embedding.* pair with the same label_emb keys.
UNET_MAP_BASIC = {
    ("{}.{}".format(ldm, p), "{}.{}".format(hf, p))
    for ldm, hf in (
        ("label_emb.0.0", "class_embedding.linear_1"),
        ("label_emb.0.2", "class_embedding.linear_2"),
        ("label_emb.0.0", "add_embedding.linear_1"),
        ("label_emb.0.2", "add_embedding.linear_2"),
        ("input_blocks.0.0", "conv_in"),
        ("out.0", "conv_norm_out"),
        ("out.2", "conv_out"),
        ("time_embed.0", "time_embedding.linear_1"),
        ("time_embed.2", "time_embedding.linear_2"),
    )
    for p in ("weight", "bias")
}
def unet_to_diffusers(unet_config):
    """Build a dict mapping diffusers UNet parameter names to ldm/SD names.

    ``unet_config`` must contain ``num_res_blocks``, ``channel_mult``,
    ``transformer_depth`` and ``transformer_depth_output`` (the depth lists
    are copied because they are consumed destructively below); it may also
    contain ``transformer_depth_middle``.  Returns {} when the config has no
    ``num_res_blocks`` key.  Keys of the result are diffusers-style names
    ("down_blocks...", "mid_block...", "up_blocks..."); values are the
    corresponding ldm names ("input_blocks...", "middle_block...",
    "output_blocks...").
    """
    if "num_res_blocks" not in unet_config:
        return {}
    num_res_blocks = unet_config["num_res_blocks"]
    channel_mult = unet_config["channel_mult"]
    # Copies: entries are pop()ed while walking the down/up paths.
    transformer_depth = unet_config["transformer_depth"][:]
    transformer_depth_output = unet_config["transformer_depth_output"][:]
    num_blocks = len(channel_mult)
    # NOTE(review): if this key is missing, range(transformers_mid) below is
    # range(None) and raises TypeError — presumably every config reaching the
    # mid-block section defines it; confirm against callers.
    transformers_mid = unet_config.get("transformer_depth_middle", None)
    diffusers_unet_map = {}
    # --- down path: diffusers down_blocks <-> ldm input_blocks ---
    for x in range(num_blocks):
        # n = flat index into ldm input_blocks (index 0 is conv_in).
        # Assumes each preceding level also contributed num_res_blocks[x] + 1
        # entries — TODO confirm counts are uniform across levels.
        n = 1 + (num_res_blocks[x] + 1) * x
        for i in range(num_res_blocks[x]):
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
            # One depth entry per resnet; 0 means no attention at this spot.
            num_transformers = transformer_depth.pop(0)
            if num_transformers > 0:
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            n += 1
        # Downsampler follows the level's resnets (maps even past the last
        # level; harmless if no such ldm key exists — the map is name->name).
        for k in ["weight", "bias"]:
            diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)
    # --- middle block: ldm middle_block.{0,1,2} = resnet, attention, resnet ---
    i = 0
    for b in UNET_MAP_ATTENTIONS:
        diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
    for t in range(transformers_mid):
        for b in TRANSFORMER_BLOCKS:
            diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)
    # The two mid resnets live at middle_block indices 0 and 2.
    for i, n in enumerate([0, 2]):
        for b in UNET_MAP_RESNET:
            diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
    # --- up path: mirror of the down path, so reverse the per-level counts
    # and consume transformer_depth_output from the end. ---
    num_res_blocks = list(reversed(num_res_blocks))
    for x in range(num_blocks):
        # n = flat index into ldm output_blocks for this level.
        n = (num_res_blocks[x] + 1) * x
        # Up levels carry one extra resnet compared to down levels.
        l = num_res_blocks[x] + 1
        for i in range(l):
            # c = sub-module index of the upsampler inside this output block
            # (1 after the resnet, 2 if an attention module sits in between).
            c = 0
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
            c += 1
            num_transformers = transformer_depth_output.pop()
            if num_transformers > 0:
                c += 1
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            # Upsampler is attached only to the last output block of a level.
            if i == l - 1:
                for k in ["weight", "bias"]:
                    diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
            n += 1
    # Non-block parameters: tuples are (ldm_name, diffusers_name).
    for k in UNET_MAP_BASIC:
        diffusers_unet_map[k[1]] = k[0]
    return diffusers_unet_map