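"""UNet definition and model-architecture detection utilities.

Contains the parameter-name mapping between the LDM and diffusers UNet layouts,
the UNetModel1 module, and helpers that infer a model configuration (standard
UNet, MMDiT, Stable Cascade, Stable Audio DiT, Aura Flow, Hunyuan DiT, Flux)
from a state dict."""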
import logging
import math
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from src.Utilities import util
from src.AutoEncoders import ResBlock
from src.NeuralNetwork import transformer
from src.cond import cast
from src.sample import sampling, sampling_util
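# Parameter-name tables used by unet_to_diffusers(): they relate the LDM-style
# UNet layout (input_blocks / middle_block / output_blocks) to the diffusers
# layout (down_blocks / mid_block / up_blocks).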
UNET_MAP_ATTENTIONS = {
    "proj_in.weight", "proj_in.bias", "proj_out.weight", "proj_out.bias", "norm.weight", "norm.bias",
}
TRANSFORMER_BLOCKS = {
"norm1.weight", "norm1.bias", "norm2.weight", "norm2.bias", "norm3.weight", "norm3.bias",
"attn1.to_q.weight", "attn1.to_k.weight", "attn1.to_v.weight", "attn1.to_out.0.weight", "attn1.to_out.0.bias",
"attn2.to_q.weight", "attn2.to_k.weight", "attn2.to_v.weight", "attn2.to_out.0.weight", "attn2.to_out.0.bias",
"ff.net.0.proj.weight", "ff.net.0.proj.bias", "ff.net.2.weight", "ff.net.2.bias",
}
UNET_MAP_RESNET = {
"in_layers.2.weight": "conv1.weight", "in_layers.2.bias": "conv1.bias",
"emb_layers.1.weight": "time_emb_proj.weight", "emb_layers.1.bias": "time_emb_proj.bias",
"out_layers.3.weight": "conv2.weight", "out_layers.3.bias": "conv2.bias",
"skip_connection.weight": "conv_shortcut.weight", "skip_connection.bias": "conv_shortcut.bias",
"in_layers.0.weight": "norm1.weight", "in_layers.0.bias": "norm1.bias",
"out_layers.0.weight": "norm2.weight", "out_layers.0.bias": "norm2.bias",
}
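# Each entry below is a (ldm_key, diffusers_key) pair for parameters that live
# outside the repeated block structure (embeddings, input/output convolutions).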
UNET_MAP_BASIC = {
("label_emb.0.0.weight", "class_embedding.linear_1.weight"), ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
("label_emb.0.2.weight", "class_embedding.linear_2.weight"), ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
("label_emb.0.0.weight", "add_embedding.linear_1.weight"), ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
("label_emb.0.2.weight", "add_embedding.linear_2.weight"), ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"), ("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"), ("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"), ("out.2.bias", "conv_out.bias"),
("time_embed.0.weight", "time_embedding.linear_1.weight"), ("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"), ("time_embed.2.bias", "time_embedding.linear_2.bias"),
}
oai_ops = cast.disable_weight_init
def unet_to_diffusers(unet_config: dict) -> dict:
    """Build a mapping from diffusers-style UNet parameter names to the LDM-style
    names used by this code base, based on the given UNet configuration."""
    if "num_res_blocks" not in unet_config:
        return {}
num_res_blocks, channel_mult = unet_config["num_res_blocks"], unet_config["channel_mult"]
transformer_depth, transformer_depth_output = unet_config["transformer_depth"][:], unet_config["transformer_depth_output"][:]
num_blocks, transformers_mid = len(channel_mult), unet_config.get("transformer_depth_middle", None)
diffusers_unet_map = {}
for x in range(num_blocks):
n = 1 + (num_res_blocks[x] + 1) * x
for i in range(num_res_blocks[x]):
for b in UNET_MAP_RESNET:
diffusers_unet_map[f"down_blocks.{x}.resnets.{i}.{UNET_MAP_RESNET[b]}"] = f"input_blocks.{n}.0.{b}"
num_transformers = transformer_depth.pop(0)
if num_transformers > 0:
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map[f"down_blocks.{x}.attentions.{i}.{b}"] = f"input_blocks.{n}.1.{b}"
for t in range(num_transformers):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map[f"down_blocks.{x}.attentions.{i}.transformer_blocks.{t}.{b}"] = f"input_blocks.{n}.1.transformer_blocks.{t}.{b}"
n += 1
for k in ["weight", "bias"]:
diffusers_unet_map[f"down_blocks.{x}.downsamplers.0.conv.{k}"] = f"input_blocks.{n}.0.op.{k}"
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map[f"mid_block.attentions.0.{b}"] = f"middle_block.1.{b}"
for t in range(transformers_mid):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map[f"mid_block.attentions.0.transformer_blocks.{t}.{b}"] = f"middle_block.1.transformer_blocks.{t}.{b}"
for i, n in enumerate([0, 2]):
for b in UNET_MAP_RESNET:
diffusers_unet_map[f"mid_block.resnets.{i}.{UNET_MAP_RESNET[b]}"] = f"middle_block.{n}.{b}"
num_res_blocks = list(reversed(num_res_blocks))
for x in range(num_blocks):
n = (num_res_blocks[x] + 1) * x
for i in range(num_res_blocks[x] + 1):
c = 0
for b in UNET_MAP_RESNET:
diffusers_unet_map[f"up_blocks.{x}.resnets.{i}.{UNET_MAP_RESNET[b]}"] = f"output_blocks.{n}.0.{b}"
c += 1
num_transformers = transformer_depth_output.pop()
if num_transformers > 0:
c += 1
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map[f"up_blocks.{x}.attentions.{i}.{b}"] = f"output_blocks.{n}.1.{b}"
for t in range(num_transformers):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map[f"up_blocks.{x}.attentions.{i}.transformer_blocks.{t}.{b}"] = f"output_blocks.{n}.1.transformer_blocks.{t}.{b}"
if i == num_res_blocks[x]:
for k in ["weight", "bias"]:
diffusers_unet_map[f"up_blocks.{x}.upsamplers.0.conv.{k}"] = f"output_blocks.{n}.{c}.conv.{k}"
n += 1
for k in UNET_MAP_BASIC:
diffusers_unet_map[k[1]] = k[0]
return diffusers_unet_map
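def _example_remap_diffusers_unet(diffusers_sd: Dict[str, torch.Tensor], unet_config: dict) -> Dict[str, torch.Tensor]:
    """Hedged usage sketch: convert a diffusers-format UNet state dict to the LDM
    key layout using the mapping above. `diffusers_sd` and `unet_config` are
    caller-supplied; this helper is illustrative and not part of the original API."""
    mapping = unet_to_diffusers(unet_config)  # diffusers key -> LDM key
    return {ldm_key: diffusers_sd[diff_key] for diff_key, ldm_key in mapping.items() if diff_key in diffusers_sd}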
def apply_control1(h: torch.Tensor, control: Any, name: str) -> torch.Tensor:
    # No-op in this build: control inputs (such as ControlNet residuals) are not
    # applied, and the hidden states are returned unchanged.
    return h
class UNetModel1(nn.Module):
    """Denoising UNet with timestep embedding and optional cross-attention
    (SpatialTransformer) blocks, following the LDM block layout."""
    def __init__(self, image_size: int, in_channels: int, model_channels: int, out_channels: int, num_res_blocks: list,
                 dropout: float = 0.0, channel_mult: tuple = (1, 2, 4, 8), conv_resample: bool = True, dims: int = 2,
                 num_classes: Optional[int] = None, use_checkpoint: bool = False, dtype: torch.dtype = torch.float32,
                 num_heads: int = -1, num_head_channels: int = -1, num_heads_upsample: int = -1,
                 use_scale_shift_norm: bool = False, resblock_updown: bool = False, use_new_attention_order: bool = False,
                 use_spatial_transformer: bool = False, transformer_depth: Optional[list] = None, context_dim: Optional[int] = None,
                 n_embed: Optional[int] = None, legacy: bool = True, disable_self_attentions: Optional[list] = None,
                 num_attention_blocks: Optional[list] = None, disable_middle_self_attn: bool = False,
                 use_linear_in_transformer: bool = False, adm_in_channels: Optional[int] = None,
                 transformer_depth_middle: Optional[int] = None, transformer_depth_output: Optional[list] = None,
                 use_temporal_resblock: bool = False, use_temporal_attention: bool = False,
                 time_context_dim: Optional[int] = None, extra_ff_mix_layer: bool = False, use_spatial_context: bool = False,
                 merge_strategy: Any = None, merge_factor: float = 0.0, video_kernel_size: Optional[int] = None,
                 disable_temporal_crossattention: bool = False, max_ddpm_temb_period: int = 10000,
                 device: Optional[torch.device] = None, operations: Any = oai_ops):
super().__init__()
if context_dim is not None:
self.context_dim = context_dim
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_head_channels == -1:
assert num_heads != -1
self.in_channels, self.model_channels, self.out_channels = in_channels, model_channels, out_channels
self.num_res_blocks, self.dropout, self.channel_mult = num_res_blocks, dropout, channel_mult
self.conv_resample, self.num_classes, self.use_checkpoint = conv_resample, num_classes, use_checkpoint
self.dtype, self.num_heads, self.num_head_channels = dtype, num_heads, num_head_channels
self.num_heads_upsample, self.use_temporal_resblocks = num_heads_upsample, use_temporal_resblock
self.predict_codebook_ids, self.default_num_video_frames = n_embed is not None, None
transformer_depth, transformer_depth_output = transformer_depth[:], transformer_depth_output[:]
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device), nn.SiLU(),
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device))
if adm_in_channels is not None:
self.label_emb = nn.Sequential(
nn.Sequential(
operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(),
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device)))
self.input_blocks = nn.ModuleList([sampling.TimestepEmbedSequential1(
operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device))])
self._feature_size, input_block_chans, ch, ds = model_channels, [model_channels], model_channels, 1
self.double_blocks = nn.ModuleList()
def make_attn(ch, depth, context_dim, disable_self_attn=False):
dim_head = ch // num_heads if num_head_channels == -1 else num_head_channels
heads = num_heads if num_head_channels == -1 else ch // num_head_channels
return transformer.SpatialTransformer(ch, heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations)
def make_res(ch, out_ch=None, down=False, up=False):
return ResBlock.ResBlock1(channels=ch, emb_channels=time_embed_dim, dropout=dropout,
out_channels=out_ch, use_checkpoint=use_checkpoint, dims=dims,
use_scale_shift_norm=use_scale_shift_norm, down=down, up=up,
dtype=self.dtype, device=device, operations=operations)
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [make_res(ch, mult * model_channels)]
ch = mult * model_channels
num_trans = transformer_depth.pop(0)
if num_trans > 0 and (not util.exists(num_attention_blocks) or nr < num_attention_blocks[level]):
layers.append(make_attn(ch, num_trans, context_dim))
self.input_blocks.append(sampling.TimestepEmbedSequential1(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(sampling.TimestepEmbedSequential1(
make_res(ch, out_ch, down=True) if resblock_updown else
ResBlock.Downsample1(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations)))
ch, input_block_chans, ds = out_ch, input_block_chans + [out_ch], ds * 2
self._feature_size += ch
dim_head = ch // num_heads if num_head_channels == -1 else num_head_channels
mid_block = [make_res(ch)]
self.middle_block = None
if transformer_depth_middle >= -1:
if transformer_depth_middle >= 0:
mid_block += [make_attn(ch, transformer_depth_middle, context_dim, disable_middle_self_attn), make_res(ch)]
self.middle_block = sampling.TimestepEmbedSequential1(*mid_block)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [make_res(ch + ich, model_channels * mult)]
ch = model_channels * mult
num_trans = transformer_depth_output.pop()
if num_trans > 0 and (not util.exists(num_attention_blocks) or i < num_attention_blocks[level]):
layers.append(make_attn(ch, num_trans, context_dim))
if level and i == self.num_res_blocks[level]:
layers.append(make_res(ch, ch, up=True) if resblock_updown else
ResBlock.Upsample1(ch, conv_resample, dims=dims, out_channels=ch, dtype=self.dtype, device=device, operations=operations))
ds //= 2
self.output_blocks.append(sampling.TimestepEmbedSequential1(*layers))
self._feature_size += ch
self.out = nn.Sequential(operations.GroupNorm(32, ch, dtype=self.dtype, device=device), nn.SiLU(),
util.zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)))
    def forward(self, x: torch.Tensor, timesteps: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None,
                y: Optional[torch.Tensor] = None, control: Any = None,
                transformer_options: Optional[Dict[str, Any]] = None, **kwargs: Any) -> torch.Tensor:
        # Use a fresh dict when none is supplied to avoid the shared-mutable-default pitfall.
        transformer_options = {} if transformer_options is None else transformer_options
        transformer_options["original_shape"], transformer_options["transformer_index"] = list(x.shape), 0
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator, time_context = kwargs.get("image_only_indicator"), kwargs.get("time_context")
if self.num_classes is None:
y = None
elif y is None:
raise ValueError("y is required for models with num_classes")
emb = self.time_embed(sampling_util.timestep_embedding(timesteps, self.model_channels).to(device=x.device, dtype=x.dtype))
if y is not None:
emb = emb + self.label_emb(y.to(device=x.device, dtype=x.dtype))
hs, h = [], x
        for block_id, module in enumerate(self.input_blocks):
            transformer_options["block"] = ("input", block_id)
h = ResBlock.forward_timestep_embed1(module, h, emb, context, transformer_options,
time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control1(h, control, "input")
hs.append(h)
transformer_options["block"] = ("middle", 0)
if self.middle_block is not None:
h = ResBlock.forward_timestep_embed1(self.middle_block, h, emb, context, transformer_options,
time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control1(h, control, "middle")
        for block_id, module in enumerate(self.output_blocks):
            transformer_options["block"] = ("output", block_id)
hsp = apply_control1(hs.pop(), control, "output")
h = torch.cat([h, hsp], dim=1)
del hsp
h = ResBlock.forward_timestep_embed1(module, h, emb, context, transformer_options,
hs[-1].shape if hs else None, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
return self.out(h.type(x.dtype))
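def _example_build_sd15_unet() -> torch.Tensor:
    """Hedged usage sketch: construct a UNetModel1 with an SD1.5-like configuration
    and run a dummy forward pass. The values below are illustrative assumptions and
    are not read from any checkpoint; this helper is not part of the original API."""
    model = UNetModel1(
        image_size=32, in_channels=4, model_channels=320, out_channels=4,
        num_res_blocks=[2, 2, 2, 2], channel_mult=(1, 2, 4, 4),
        num_heads=8, use_spatial_transformer=True, context_dim=768,
        transformer_depth=[1, 1, 1, 1, 1, 1, 0, 0],                     # one entry per input res block
        transformer_depth_output=[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],  # one entry per output res block
        transformer_depth_middle=1,
    )
    x = torch.zeros(1, 4, 64, 64)              # latent image
    t = torch.zeros(1, dtype=torch.long)       # timestep indices
    context = torch.zeros(1, 77, 768)          # text-encoder embeddings
    with torch.no_grad():
        return model(x, timesteps=t, context=context)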
def detect_unet_config(state_dict: Dict[str, torch.Tensor], key_prefix: str) -> Optional[Dict[str, Any]]:
    """Infer a model configuration from the parameter names and shapes in a state dict.
    Returns None when no supported architecture is recognised."""
    state_dict_keys = list(state_dict.keys())
# MMDIT model
if f"{key_prefix}joint_blocks.0.context_block.attn.qkv.weight" in state_dict_keys:
cfg = {"in_channels": state_dict[f"{key_prefix}x_embedder.proj.weight"].shape[1],
"patch_size": state_dict[f"{key_prefix}x_embedder.proj.weight"].shape[2],
"depth": state_dict[f"{key_prefix}x_embedder.proj.weight"].shape[0] // 64, "input_size": None}
if f"{key_prefix}final_layer.linear.weight" in state_dict:
cfg["out_channels"] = state_dict[f"{key_prefix}final_layer.linear.weight"].shape[0] // (cfg["patch_size"] ** 2)
if f"{key_prefix}y_embedder.mlp.0.weight" in state_dict_keys:
cfg["adm_in_channels"] = state_dict[f"{key_prefix}y_embedder.mlp.0.weight"].shape[1]
if f"{key_prefix}context_embedder.weight" in state_dict_keys:
w = state_dict[f"{key_prefix}context_embedder.weight"]
cfg["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": w.shape[1], "out_features": w.shape[0]}}
if f"{key_prefix}pos_embed" in state_dict_keys:
cfg["num_patches"] = state_dict[f"{key_prefix}pos_embed"].shape[1]
cfg["pos_embed_max_size"] = round(math.sqrt(cfg["num_patches"]))
if f"{key_prefix}joint_blocks.0.context_block.attn.ln_q.weight" in state_dict_keys:
cfg["qk_norm"] = "rms"
cfg["pos_embed_scaling_factor"] = None
if f"{key_prefix}context_processor.layers.0.attn.qkv.weight" in state_dict_keys:
cfg["context_processor_layers"] = transformer.count_blocks(state_dict_keys, f"{key_prefix}context_processor.layers." + "{}.")
return cfg
# Stable Cascade
if f"{key_prefix}clf.1.weight" in state_dict_keys:
cfg = {}
if f"{key_prefix}clip_txt_mapper.weight" in state_dict_keys:
cfg["stable_cascade_stage"] = "c"
w = state_dict[f"{key_prefix}clip_txt_mapper.weight"]
if w.shape[0] == 1536:
cfg.update({"c_cond": 1536, "c_hidden": [1536, 1536], "nhead": [24, 24], "blocks": [[4, 12], [12, 4]]})
elif w.shape[0] == 2048:
cfg["c_cond"] = 2048
elif f"{key_prefix}clip_mapper.weight" in state_dict_keys:
cfg["stable_cascade_stage"] = "b"
w = state_dict[f"{key_prefix}down_blocks.1.0.channelwise.0.weight"]
if w.shape[-1] == 640:
cfg.update({"c_hidden": [320, 640, 1280, 1280], "nhead": [-1, -1, 20, 20], "blocks": [[2, 6, 28, 6], [6, 28, 6, 2]], "block_repeat": [[1, 1, 1, 1], [3, 3, 2, 2]]})
elif w.shape[-1] == 576:
cfg.update({"c_hidden": [320, 576, 1152, 1152], "nhead": [-1, 9, 18, 18], "blocks": [[2, 4, 14, 4], [4, 14, 4, 2]], "block_repeat": [[1, 1, 1, 1], [2, 2, 2, 2]]})
return cfg
# Stable Audio DIT
if f"{key_prefix}transformer.rotary_pos_emb.inv_freq" in state_dict_keys:
return {"audio_model": "dit1.0"}
# Aura Flow DIT
if f"{key_prefix}double_layers.0.attn.w1q.weight" in state_dict_keys:
double_layers = transformer.count_blocks(state_dict_keys, f"{key_prefix}double_layers." + "{}.")
single_layers = transformer.count_blocks(state_dict_keys, f"{key_prefix}single_layers." + "{}.")
return {"max_seq": state_dict[f"{key_prefix}positional_encoding"].shape[1],
"cond_seq_dim": state_dict[f"{key_prefix}cond_seq_linear.weight"].shape[1],
"n_double_layers": double_layers, "n_layers": double_layers + single_layers}
# Hunyuan DiT
if f"{key_prefix}mlp_t5.0.weight" in state_dict_keys:
cfg = {"image_model": "hydit", "depth": transformer.count_blocks(state_dict_keys, f"{key_prefix}blocks." + "{}."),
"hidden_size": state_dict[f"{key_prefix}x_embedder.proj.weight"].shape[0]}
if cfg["hidden_size"] == 1408 and cfg["depth"] == 40:
cfg["mlp_ratio"] = 4.3637
if state_dict[f"{key_prefix}extra_embedder.0.weight"].shape[1] == 3968:
cfg.update({"size_cond": True, "use_style_cond": True, "image_model": "hydit1"})
return cfg
# Flux
if f"{key_prefix}double_blocks.0.img_attn.norm.key_norm.scale" in state_dict_keys:
return {"image_model": "flux", "in_channels": 16, "vec_in_dim": 768, "context_in_dim": 4096,
"hidden_size": 3072, "mlp_ratio": 4.0, "num_heads": 24,
"depth": transformer.count_blocks(state_dict_keys, f"{key_prefix}double_blocks." + "{}."),
"depth_single_blocks": transformer.count_blocks(state_dict_keys, f"{key_prefix}single_blocks." + "{}."),
"axes_dim": [16, 56, 56], "theta": 10000, "qkv_bias": True,
"guidance_embed": f"{key_prefix}guidance_in.in_layer.weight" in state_dict_keys}
if f"{key_prefix}input_blocks.0.0.weight" not in state_dict_keys:
return None
# Standard UNet
cfg = {"use_checkpoint": False, "image_size": 32, "use_spatial_transformer": True, "legacy": False}
if f"{key_prefix}label_emb.0.0.weight" in state_dict_keys:
cfg["num_classes"], cfg["adm_in_channels"] = "sequential", state_dict[f"{key_prefix}label_emb.0.0.weight"].shape[1]
else:
cfg["adm_in_channels"] = None
model_channels = state_dict[f"{key_prefix}input_blocks.0.0.weight"].shape[0]
in_channels = state_dict[f"{key_prefix}input_blocks.0.0.weight"].shape[1]
out_channels = state_dict.get(f"{key_prefix}out.2.weight", torch.zeros(4)).shape[0] or 4
num_res_blocks, channel_mult, transformer_depth, transformer_depth_output = [], [], [], []
context_dim, use_linear_in_transformer = None, False
current_res, last_res_blocks, last_channel_mult = 1, 0, 0
input_block_count = transformer.count_blocks(state_dict_keys, f"{key_prefix}input_blocks" + ".{}.")
for count in range(input_block_count):
prefix = f"{key_prefix}input_blocks.{count}."
prefix_output = f"{key_prefix}output_blocks.{input_block_count - count - 1}."
block_keys = sorted([k for k in state_dict_keys if k.startswith(prefix)])
if not block_keys:
break
block_keys_output = sorted([k for k in state_dict_keys if k.startswith(prefix_output)])
if f"{prefix}0.op.weight" in block_keys:
num_res_blocks.append(last_res_blocks)
channel_mult.append(last_channel_mult)
current_res *= 2
last_res_blocks, last_channel_mult = 0, 0
out = transformer.calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
transformer_depth_output.append(out[0] if out else 0)
else:
if f"{prefix}0.in_layers.0.weight" in block_keys:
last_res_blocks += 1
last_channel_mult = state_dict[f"{prefix}0.out_layers.3.weight"].shape[0] // model_channels
out = transformer.calculate_transformer_depth(prefix, state_dict_keys, state_dict)
if out:
transformer_depth.append(out[0])
if context_dim is None:
context_dim, use_linear_in_transformer = out[1], out[2]
else:
transformer_depth.append(0)
if f"{prefix_output}0.in_layers.0.weight" in block_keys_output:
out = transformer.calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
transformer_depth_output.append(out[0] if out else 0)
num_res_blocks.append(last_res_blocks)
channel_mult.append(last_channel_mult)
if f"{key_prefix}middle_block.1.proj_in.weight" in state_dict_keys:
transformer_depth_middle = transformer.count_blocks(state_dict_keys, f"{key_prefix}middle_block.1.transformer_blocks." + "{}")
elif f"{key_prefix}middle_block.0.in_layers.0.weight" in state_dict_keys:
transformer_depth_middle = -1
else:
transformer_depth_middle = -2
cfg.update({"in_channels": in_channels, "out_channels": out_channels, "model_channels": model_channels,
"num_res_blocks": num_res_blocks, "transformer_depth": transformer_depth,
"transformer_depth_output": transformer_depth_output, "channel_mult": channel_mult,
"transformer_depth_middle": transformer_depth_middle, "use_linear_in_transformer": use_linear_in_transformer,
"context_dim": context_dim, "use_temporal_resblock": False, "use_temporal_attention": False})
return cfg
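def _example_detect_config(state_dict: Dict[str, torch.Tensor]) -> Optional[Dict[str, Any]]:
    """Hedged usage sketch: try the key prefix commonly used by full checkpoints
    ("model.diffusion_model.") and then a bare UNet file (empty prefix), returning
    the first recognised configuration. The prefix list is an assumption."""
    for prefix in ("model.diffusion_model.", ""):
        cfg = detect_unet_config(state_dict, prefix)
        if cfg is not None:
            return cfg
    return None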
def model_config_from_unet_config(unet_config: Dict[str, Any], state_dict: Optional[Dict[str, torch.Tensor]] = None) -> Any:
from src.SD15 import SD15
for model_config in SD15.models:
if model_config.matches(unet_config, state_dict):
return model_config(unet_config)
logging.error(f"no match {unet_config}")
return None
def model_config_from_unet(state_dict: Dict[str, torch.Tensor], unet_key_prefix: str, use_base_if_no_match: bool = False) -> Any:
    # Note: use_base_if_no_match is not used in this implementation; an
    # unrecognised architecture simply returns None.
    unet_config = detect_unet_config(state_dict, unet_key_prefix)
    return model_config_from_unet_config(unet_config, state_dict) if unet_config else None
def unet_dtype1(device: Optional[torch.device] = None, model_params: int = 0,
                supported_dtypes: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32]) -> torch.dtype:
    # This build always selects float16 for the UNet, regardless of the arguments.
    return torch.float16