| | import torch |
| | import safetensors |
| | from accelerate import init_empty_weights |
| | from accelerate.utils.modeling import set_module_tensor_to_device |
| | from safetensors.torch import load_file, save_file |
| | from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer |
| | from typing import List |
| | from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel |
| | from . import model_util |
| | from . import sdxl_original_unet |
| | from .utils import setup_logging |
| |
|
| | setup_logging() |
| | import logging |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
# Numeric scale factor associated with the SDXL VAE.
VAE_SCALE_FACTOR = 0.13025
# Version identifier string for the SDXL base 1.0 model.
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"

# Hugging Face model id used by save_diffusers_checkpoint as the fallback source
# for the scheduler, the tokenizers and (when none is given) the VAE.
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"

# Keyword arguments handed to diffusers' UNet2DConditionModel(**...) when the
# U-Net is exported in Diffusers format (see save_diffusers_checkpoint).
DIFFUSERS_SDXL_UNET_CONFIG = {
    "act_fn": "silu",
    "addition_embed_type": "text_time",
    "addition_embed_type_num_heads": 64,
    "addition_time_embed_dim": 256,
    "attention_head_dim": [5, 10, 20],
    "block_out_channels": [320, 640, 1280],
    "center_input_sample": False,
    "class_embed_type": None,
    "class_embeddings_concat": False,
    "conv_in_kernel": 3,
    "conv_out_kernel": 3,
    "cross_attention_dim": 2048,
    "cross_attention_norm": None,
    "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
    "downsample_padding": 1,
    "dual_cross_attention": False,
    "encoder_hid_dim": None,
    "encoder_hid_dim_type": None,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 4,
    "layers_per_block": 2,
    "mid_block_only_cross_attention": None,
    "mid_block_scale_factor": 1,
    "mid_block_type": "UNetMidBlock2DCrossAttn",
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "num_attention_heads": None,
    "num_class_embeds": None,
    "only_cross_attention": False,
    "out_channels": 4,
    "projection_class_embeddings_input_dim": 2816,
    "resnet_out_scale_factor": 1.0,
    "resnet_skip_time_act": False,
    "resnet_time_scale_shift": "default",
    "sample_size": 128,
    "time_cond_proj_dim": None,
    "time_embedding_act_fn": None,
    "time_embedding_dim": None,
    "time_embedding_type": "positional",
    "timestep_post_act": None,
    "transformer_layers_per_block": [1, 2, 10],
    "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
    "upcast_attention": False,
    "use_linear_projection": True,
}
| |
|
| |
|
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
    """Rename SDXL text-encoder-2 weights to the ``text_model.*`` key layout.

    Args:
        checkpoint: Mapping of ``conditioner.embedders.1.model.*`` keys to tensors.
        max_length: Accepted for interface compatibility; not used by this function.

    Returns:
        ``(new_sd, logit_scale)``: the renamed state dict and the value stored under
        ``conditioner.embedders.1.model.logit_scale`` (``None`` when absent).

    Raises:
        ValueError: If a ``resblocks`` key matches none of the known sub-module patterns.
    """
    SDXL_KEY_PREFIX = "conditioner.embedders.1.model."

    def convert_key(key):
        # The more specific "transformer." prefix must be rewritten before the generic one.
        key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
        key = key.replace(SDXL_KEY_PREFIX, "text_model.")

        if "resblocks" in key:
            key = key.replace(".resblocks.", ".layers.")
            if ".ln_" in key:
                return key.replace(".ln_", ".layer_norm")
            if ".mlp." in key:
                return key.replace(".c_fc.", ".fc1.").replace(".c_proj.", ".fc2.")
            if ".attn.out_proj" in key:
                return key.replace(".attn.out_proj.", ".self_attn.out_proj.")
            if ".attn.in_proj" in key:
                # Fused q/k/v tensors are split in a dedicated pass below.
                return None
            raise ValueError(f"unexpected key in SD: {key}")

        if ".positional_embedding" in key:
            return key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        if ".text_projection" in key:
            return key.replace("text_model.text_projection", "text_projection.weight")
        if ".logit_scale" in key:
            # Returned separately instead of being stored in the state dict.
            return None
        if ".token_embedding" in key:
            return key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        if ".ln_final" in key:
            return key.replace(".ln_final", ".final_layer_norm")
        if ".embeddings.position_ids" in key:
            return None
        return key

    source_keys = list(checkpoint.keys())

    # Pass 1: plain one-to-one renames.
    new_sd = {}
    for source_key in source_keys:
        target_key = convert_key(source_key)
        if target_key is not None:
            new_sd[target_key] = checkpoint[source_key]

    # Pass 2: split each fused attention in_proj tensor into three equal q/k/v chunks.
    for source_key in source_keys:
        if not (".resblocks" in source_key and ".attn.in_proj_" in source_key):
            continue

        chunks = torch.chunk(checkpoint[source_key], 3)
        suffix = ".weight" if "weight" in source_key else ".bias"
        stem = source_key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
        stem = stem.replace("_weight", "").replace("_bias", "").replace(".attn.in_proj", ".self_attn.")
        for index, proj_name in enumerate(("q_proj", "k_proj", "v_proj")):
            new_sd[stem + proj_name + suffix] = chunks[index]

    logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)

    # A source key that already ended in ".weight" yields a doubled suffix after renaming; collapse it.
    if "text_projection.weight.weight" in new_sd:
        logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
        new_sd["text_projection.weight"] = new_sd.pop("text_projection.weight.weight")

    return new_sd, logit_scale
| |
|
| |
|
| | |
| | def _load_state_dict_on_device(model, state_dict, device, dtype=None): |
| | |
| | missing_keys = list(model.state_dict().keys() - state_dict.keys()) |
| | unexpected_keys = list(state_dict.keys() - model.state_dict().keys()) |
| |
|
| | |
| | if not missing_keys and not unexpected_keys: |
| | for k in list(state_dict.keys()): |
| | set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype) |
| | return "<All keys matched successfully>" |
| |
|
| | |
| | error_msgs: List[str] = [] |
| | if missing_keys: |
| | error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))) |
| | if unexpected_keys: |
| | error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys))) |
| |
|
| | raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))) |
| |
|
| |
|
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
    """Load the U-Net, both text encoders and the VAE from a single SDXL checkpoint file.

    Args:
        model_version: Not used by this function; kept for caller compatibility.
        ckpt_path: Path of the checkpoint; ``model_util.is_safetensors`` selects the loader.
        map_location: Device the weights are loaded to and materialized on.
        dtype: Optional dtype applied while loading the U-Net and VAE weights.
            The text encoders are loaded without a dtype override.
        disable_mmap: For safetensors files, read the whole file into memory and
            deserialize from bytes instead of calling ``load_file`` on the path.

    Returns:
        ``(text_model1, text_model2, vae, unet, logit_scale, ckpt_info)``.
        ``ckpt_info`` is ``(epoch, global_step)`` for non-safetensors checkpoints
        and ``None`` for safetensors files.

    Raises:
        RuntimeError: If a sub-model's expected keys do not match the checkpoint
            (raised by ``_load_state_dict_on_device``).
    """
    # Load the raw state dict.
    if model_util.is_safetensors(ckpt_path):
        if disable_mmap:
            # Close the file deterministically instead of leaking the handle.
            with open(ckpt_path, "rb") as f:
                state_dict = safetensors.torch.load(f.read())
        else:
            try:
                state_dict = load_file(ckpt_path, device=map_location)
            except Exception:
                # Best-effort fallback: retry with load_file's default device when
                # `map_location` is rejected. KeyboardInterrupt/SystemExit still propagate.
                state_dict = load_file(ckpt_path)
        epoch = None
        global_step = None
    else:
        # NOTE(review): torch.load unpickles the file; only use with trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint.get("epoch", 0)
            global_step = checkpoint.get("global_step", 0)
        else:
            state_dict = checkpoint
            epoch = 0
            global_step = 0
        # Drop the extra reference to the loaded container.
        checkpoint = None

    # U-Net: build on empty weights, then fill tensor by tensor.
    logger.info("building U-Net")
    with init_empty_weights():
        unet = sdxl_original_unet.SdxlUNet2DConditionModel()

    logger.info("loading U-Net from checkpoint")
    unet_sd = {}
    for k in list(state_dict.keys()):
        if k.startswith("model.diffusion_model."):
            unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
    info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
    logger.info(f"U-Net: {info}")

    # Text encoders.
    logger.info("building text encoders")

    text_model1_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=77,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=768,
    )
    with init_empty_weights():
        text_model1 = CLIPTextModel._from_config(text_model1_cfg)

    text_model2_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=1280,
        intermediate_size=5120,
        num_hidden_layers=32,
        num_attention_heads=20,
        max_position_embeddings=77,
        hidden_act="gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=1280,
    )
    with init_empty_weights():
        text_model2 = CLIPTextModelWithProjection(text_model2_cfg)

    logger.info("loading text encoders from checkpoint")
    te1_sd = {}
    te2_sd = {}
    for k in list(state_dict.keys()):
        if k.startswith("conditioner.embedders.0.transformer."):
            te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
        elif k.startswith("conditioner.embedders.1.model."):
            # Keep the full key: the converter strips the prefix itself.
            te2_sd[k] = state_dict.pop(k)

    # Drop position_ids if the checkpoint carries it, so the strict key comparison
    # in _load_state_dict_on_device does not see it.
    if "text_model.embeddings.position_ids" in te1_sd:
        te1_sd.pop("text_model.embeddings.position_ids")

    info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)
    logger.info(f"text encoder 1: {info1}")

    converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
    info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location)
    logger.info(f"text encoder 2: {info2}")

    # VAE.
    logger.info("building VAE")
    vae_config = model_util.create_vae_diffusers_config()
    with init_empty_weights():
        vae = AutoencoderKL(**vae_config)

    logger.info("loading VAE from checkpoint")
    converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
    info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
    logger.info(f"VAE: {info}")

    ckpt_info = (epoch, global_step) if epoch is not None else None
    return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
| |
|
| |
|
def make_unet_conversion_map():
    """Build the list of ``(sd_prefix, hf_prefix)`` key-prefix pairs for the SDXL U-Net.

    The first element of each pair is a prefix in the original SDXL U-Net layout
    (``input_blocks`` / ``middle_block`` / ``output_blocks`` ...), the second is the
    matching prefix in the Diffusers layout (``down_blocks`` / ``mid_block`` /
    ``up_blocks`` ...). Every prefix ends with a dot.

    Returns:
        list[tuple[str, str]]: The prefix pairs, resnet blocks expanded to their sub-layers.
    """
    # Sub-layer renames applied inside every resnet block.
    resnet_parts = (
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    )

    # Block-level prefixes.
    block_pairs = []
    for block in range(3):
        for idx in range(2):
            sd_base = f"input_blocks.{3 * block + idx + 1}."
            block_pairs.append((sd_base + "0.", f"down_blocks.{block}.resnets.{idx}."))
            block_pairs.append((sd_base + "1.", f"down_blocks.{block}.attentions.{idx}."))

        for idx in range(3):
            sd_base = f"output_blocks.{3 * block + idx}."
            block_pairs.append((sd_base + "0.", f"up_blocks.{block}.resnets.{idx}."))
            block_pairs.append((sd_base + "1.", f"up_blocks.{block}.attentions.{idx}."))

        block_pairs.append((f"input_blocks.{3 * (block + 1)}.0.op.", f"down_blocks.{block}.downsamplers.0.conv."))
        block_pairs.append((f"output_blocks.{3 * block + 2}.2.", f"up_blocks.{block}.upsamplers.0."))

    block_pairs.append(("middle_block.1.", "mid_block.attentions.0."))
    block_pairs.extend((f"middle_block.{2 * idx}.", f"mid_block.resnets.{idx}.") for idx in range(2))

    # Expand resnet blocks into their sub-layers; everything else maps as-is.
    conversion_map = []
    for sd_prefix, hf_prefix in block_pairs:
        if "resnets" not in hf_prefix:
            conversion_map.append((sd_prefix, hf_prefix))
            continue
        conversion_map.extend((sd_prefix + sd_part, hf_prefix + hf_part) for sd_part, hf_part in resnet_parts)

    # Embedding MLPs and the input/output convolutions.
    conversion_map.extend((f"time_embed.{idx * 2}.", f"time_embedding.linear_{idx + 1}.") for idx in range(2))
    conversion_map.extend((f"label_emb.0.{idx * 2}.", f"add_embedding.linear_{idx + 1}.") for idx in range(2))
    conversion_map += [
        ("input_blocks.0.0.", "conv_in."),
        ("out.0.", "conv_norm_out."),
        ("out.2.", "conv_out."),
    ]

    return conversion_map
| |
|
| |
|
def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
    """Rename the keys of a Diffusers-layout U-Net state dict to the original SDXL layout.

    Args:
        du_sd: Mapping of Diffusers U-Net parameter names to values.

    Returns:
        A new dict with the same values under SDXL-layout keys.
    """
    # Invert the (sd_prefix, hf_prefix) pairs so lookups go Diffusers -> SDXL.
    hf_to_sd = {hf_prefix: sd_prefix for sd_prefix, hf_prefix in make_unet_conversion_map()}
    return convert_unet_state_dict(du_sd, hf_to_sd)
| |
|
| |
|
def convert_unet_state_dict(src_sd, conversion_map):
    """Rename state-dict keys by replacing their longest mapped dotted prefix.

    For each key, the final dotted component is dropped and progressively shorter
    prefixes (each with a trailing ``"."``) are looked up in ``conversion_map``.
    The first hit, i.e. the longest mapped prefix, is swapped for its mapped value
    and the remainder of the key is kept.

    Args:
        src_sd: Mapping of source key to value.
        conversion_map: Mapping of source prefix to target prefix (both ending in ``"."``).

    Returns:
        A new dict with converted keys and the original values.

    Raises:
        AssertionError: If no prefix of a key is present in ``conversion_map``.
    """
    converted_sd = {}
    for src_key, value in src_sd.items():
        fragments = src_key.split(".")[:-1]  # the last component (e.g. weight/bias) is never part of a prefix
        while fragments:
            prefix = ".".join(fragments) + "."
            if prefix in conversion_map:
                converted_sd[conversion_map[prefix] + src_key[len(prefix) :]] = value
                break
            fragments.pop()
        else:
            # Raised explicitly (not via `assert`) so the check survives `python -O`;
            # the exception type stays AssertionError for existing callers.
            raise AssertionError(f"key {src_key} not found in conversion map")

    return converted_sd
| |
|
| |
|
def convert_sdxl_unet_state_dict_to_diffusers(sd):
    """Rename the keys of an SDXL-layout U-Net state dict to the Diffusers layout.

    Args:
        sd: Mapping of original SDXL U-Net parameter names to values.

    Returns:
        A new dict with the same values under Diffusers-layout keys.
    """
    # The conversion map is already a list of (sd_prefix, hf_prefix) pairs.
    sd_to_hf = dict(make_unet_conversion_map())
    return convert_unet_state_dict(sd, sd_to_hf)
| |
|
| |
|
def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
    """Rename ``text_model.*`` text-encoder-2 weights back to the SDXL checkpoint key layout.

    Args:
        checkpoint: State dict using ``text_model.*`` / ``text_projection.weight`` keys.
        logit_scale: Value stored under ``"logit_scale"`` in the result; skipped when ``None``.

    Returns:
        A new dict keyed in the SDXL layout (without the ``conditioner.embedders.1.model.``
        prefix), with each layer's q/k/v projections concatenated into ``attn.in_proj_*``.

    Raises:
        ValueError: If a ``layers`` key matches none of the known sub-module patterns.
    """

    def convert_key(key):
        if ".position_ids" in key:
            return None

        # The more specific "encoder." prefix must be rewritten before the generic one.
        key = key.replace("text_model.encoder.", "transformer.")
        key = key.replace("text_model.", "")

        if "layers" in key:
            key = key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in key:
                return key.replace(".layer_norm", ".ln_")
            if ".mlp." in key:
                return key.replace(".fc1.", ".c_fc.").replace(".fc2.", ".c_proj.")
            if ".self_attn.out_proj" in key:
                return key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            if ".self_attn." in key:
                # q/k/v projections are fused in a dedicated pass below.
                return None
            raise ValueError(f"unexpected key in DiffUsers model: {key}")

        if ".position_embedding" in key:
            return key.replace("embeddings.position_embedding.weight", "positional_embedding")
        if ".token_embedding" in key:
            return key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        if "text_projection" in key:
            return key.replace("text_projection.weight", "text_projection")
        if "final_layer_norm" in key:
            return key.replace("final_layer_norm", "ln_final")
        return key

    source_keys = list(checkpoint.keys())

    # Pass 1: plain one-to-one renames.
    new_sd = {}
    for source_key in source_keys:
        target_key = convert_key(source_key)
        if target_key is not None:
            new_sd[target_key] = checkpoint[source_key]

    # Pass 2: concatenate q, k, v (in that order) into a single in_proj tensor per layer.
    for source_key in source_keys:
        if not ("layers" in source_key and "q_proj" in source_key):
            continue

        qkv = [checkpoint[source_key.replace("q_proj", name)] for name in ("q_proj", "k_proj", "v_proj")]
        fused_key = source_key.replace("text_model.encoder.layers.", "transformer.resblocks.")
        fused_key = fused_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
        new_sd[fused_key] = torch.cat(qkv)

    if logit_scale is not None:
        new_sd["logit_scale"] = logit_scale

    return new_sd
| |
|
| |
|
def save_stable_diffusion_checkpoint(
    output_file,
    text_encoder1,
    text_encoder2,
    unet,
    epochs,
    steps,
    ckpt_info,
    vae,
    logit_scale,
    metadata,
    save_dtype=None,
):
    """Write all SDXL sub-models into one checkpoint file in the original SDXL key layout.

    Args:
        output_file: Destination path; ``model_util.is_safetensors`` selects the format.
        text_encoder1: First text encoder, saved under ``conditioner.embedders.0.transformer.``.
        text_encoder2: Second text encoder, converted and saved under ``conditioner.embedders.1.model.``.
        unet: U-Net, saved under ``model.diffusion_model.``.
        epochs: Epoch count to record; ``ckpt_info[0]`` is added when ``ckpt_info`` is given.
        steps: Step count to record; ``ckpt_info[1]`` is added when ``ckpt_info`` is given.
        ckpt_info: Optional ``(epoch, global_step)`` pair to add on top, or ``None``.
        vae: VAE, converted and saved under ``first_stage_model.``.
        logit_scale: Forwarded to ``convert_text_encoder_2_state_dict_to_sdxl``.
        metadata: Passed to ``save_file`` for safetensors output; unused otherwise.
        save_dtype: When set, every tensor is detached, cloned, moved to CPU and cast to this dtype.

    Returns:
        int: Number of keys in the merged state dict.
    """
    merged = {}

    def add_with_prefix(prefix, tensors):
        for name, tensor in tensors.items():
            if save_dtype is not None:
                tensor = tensor.detach().clone().to("cpu").to(save_dtype)
            merged[prefix + name] = tensor

    # U-Net.
    add_with_prefix("model.diffusion_model.", unet.state_dict())

    # Text encoders: the first is stored as-is, the second needs key conversion.
    add_with_prefix("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
    add_with_prefix(
        "conditioner.embedders.1.model.",
        convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale),
    )

    # VAE.
    add_with_prefix("first_stage_model.", model_util.convert_vae_state_dict(vae.state_dict()))

    key_count = len(merged)

    # Continue the counters recorded by the source checkpoint, if any.
    if ckpt_info is not None:
        epochs = epochs + ckpt_info[0]
        steps = steps + ckpt_info[1]

    if model_util.is_safetensors(output_file):
        # Only the tensors (plus metadata) are written; epoch/step are not stored here.
        save_file(merged, output_file, metadata)
    else:
        torch.save({"state_dict": merged, "epoch": epochs, "global_step": steps}, output_file)

    return key_count
| |
|
| |
|
def save_diffusers_checkpoint(
    output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
):
    """Export the models as a Diffusers ``StableDiffusionXLPipeline`` directory.

    Args:
        output_dir: Directory passed to ``pipeline.save_pretrained``.
        text_encoder1: First text encoder, used as the pipeline's ``text_encoder``.
        text_encoder2: Second text encoder, used as the pipeline's ``text_encoder_2``.
        unet: U-Net in the original SDXL key layout; its weights are converted and
            loaded into a freshly built ``UNet2DConditionModel``.
        pretrained_model_name_or_path: Source for the scheduler, tokenizers and
            fallback VAE; ``DIFFUSERS_REF_MODEL_ID_SDXL`` is used when ``None``.
        vae: Optional VAE; loaded from ``pretrained_model_name_or_path`` when ``None``.
        use_safetensors: Forwarded as ``safe_serialization`` to ``save_pretrained``.
        save_dtype: Optional dtype the U-Net and the whole pipeline are cast to.
    """
    from diffusers import StableDiffusionXLPipeline

    # Convert the U-Net weights to the Diffusers key layout and load them into a new model.
    unet_sd = unet.state_dict()
    du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)

    diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
    if save_dtype is not None:
        diffusers_unet.to(save_dtype)
    diffusers_unet.load_state_dict(du_unet_sd)

    # Remaining pipeline components come from the reference model.
    if pretrained_model_name_or_path is None:
        pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL

    scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
    if vae is None:
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    def remove_name_or_path(model):
        # Clear the recorded source path/id on components that expose a config.
        # (The original assigned this twice; one assignment is sufficient.)
        if hasattr(model, "config"):
            model.config._name_or_path = None

    remove_name_or_path(diffusers_unet)
    remove_name_or_path(text_encoder1)
    remove_name_or_path(text_encoder2)
    remove_name_or_path(scheduler)
    remove_name_or_path(tokenizer1)
    remove_name_or_path(tokenizer2)
    remove_name_or_path(vae)

    pipeline = StableDiffusionXLPipeline(
        unet=diffusers_unet,
        text_encoder=text_encoder1,
        text_encoder_2=text_encoder2,
        vae=vae,
        scheduler=scheduler,
        tokenizer=tokenizer1,
        tokenizer_2=tokenizer2,
    )
    if save_dtype is not None:
        pipeline.to(None, save_dtype)
    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
| |
|