"""
This script utilizes code from ControlNet available at:
https://github.com/lllyasviel/ControlNet/blob/main/tool_add_control.py
Original Author: Lvmin Zhang
License: Apache License 2.0
"""
import sys
import os

# Must be set BEFORE any Hugging Face-backed import so the cache lands in /tmp
# (useful for read-only home directories / containerized runs).
os.environ['HF_HOME'] = '/tmp'

import torch
from oldm.hack import disable_verbosity
disable_verbosity()
from oldm.model import create_model
from hra import inject_trainable_hra
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--input_path', type=str, default='./models/v1-5-pruned.ckpt')
parser.add_argument('--output_path', type=str, default='./models/hra_half_init_l_8.ckpt')
parser.add_argument('--r', type=int, default=8)  # HRA rank
parser.add_argument('--apply_GS', action='store_true', default=False)  # presumably Gram-Schmidt — confirm against hra.inject_trainable_hra
args = parser.parse_args()

# Validate paths with explicit raises: `assert` is stripped under `python -O`.
if not os.path.exists(args.input_path):
    raise FileNotFoundError('Input model does not exist.')
if not os.path.exists(os.path.dirname(args.output_path)):
    raise NotADirectoryError('Output path is not valid.')
def get_node_name(name, parent_name):
    """Strip a ``parent_name`` prefix from ``name``.

    Returns ``(True, suffix)`` when ``name`` is strictly longer than
    ``parent_name`` and begins with it; otherwise ``(False, '')``.
    """
    if len(name) > len(parent_name) and name.startswith(parent_name):
        return True, name[len(parent_name):]
    return False, ''
model = create_model(config_path='./configs/oft_ldm_v15.yaml')
model.model.requires_grad_(False)
# Inject trainable HRA adapters into the frozen UNet. This mutates the module
# tree, so the new adapter parameters show up in model.state_dict() below.
unet_lora_params, train_names = inject_trainable_hra(model.model, r=args.r, apply_GS=args.apply_GS)

# map_location='cpu' lets a GPU-saved checkpoint load on CPU-only machines.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
pretrained_weights = torch.load(args.input_path, map_location='cpu')
if 'state_dict' in pretrained_weights:
    pretrained_weights = pretrained_weights['state_dict']

scratch_dict = model.state_dict()
target_dict = {}
names = []
for k in scratch_dict.keys():
    names.append(k)
    if k in pretrained_weights:
        # Present in the pretrained checkpoint: copy verbatim.
        target_dict[k] = pretrained_weights[k].clone()
    elif 'fixed_linear.' in k:
        # HRA wraps the original layer under a `fixed_linear.` sub-path;
        # initialize it from the pretrained weight at the unwrapped key.
        copy_k = k.replace('fixed_linear.', '')
        target_dict[k] = pretrained_weights[copy_k].clone()
    else:
        # Newly added adapter parameter: keep its fresh initialization.
        target_dict[k] = scratch_dict[k].clone()
        print(f'These weights are newly added: {k}')

# Dump every parameter name for later inspection/debugging.
with open('HRA_model_names.txt', 'w', encoding='utf-8') as file:
    for element in names:
        file.write(element + '\n')

model.load_state_dict(target_dict, strict=True)
torch.save(model.state_dict(), args.output_path)
print('Done.')