studio-ai / src /lora_helper.py
adarshnagrikar14
feat: add error handling and prompt enhancement in image generation
51246d7
raw
history blame
15.5 kB
from diffusers.models.attention_processor import FluxAttnProcessor2_0
from safetensors import safe_open
import re
import torch
from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_safetensors(path):
    """Read every tensor stored in a safetensors file.

    The file is opened with the PyTorch framework backend and all tensors are
    materialised on the CPU.

    Args:
        path: Location of the ``.safetensors`` file.

    Returns:
        A dict mapping each key in the file to its tensor.
    """
    with safe_open(path, framework="pt", device="cpu") as handle:
        return {name: handle.get_tensor(name) for name in handle.keys()}
def get_lora_rank(checkpoint):
    """Infer the LoRA rank from a checkpoint.

    The rank is read from the first entry whose key ends with
    ``".down.weight"``: it is that entry's ``shape[0]``.

    Args:
        checkpoint: Mapping from parameter names to tensor-like values.

    Returns:
        The rank, or ``None`` when no ``".down.weight"`` key is present.
    """
    down_key = next(
        (key for key in checkpoint.keys() if key.endswith(".down.weight")),
        None,
    )
    if down_key is None:
        # Rank cannot be determined from this checkpoint.
        return None
    return checkpoint[down_key].shape[0]
def load_checkpoint(local_path):
    """Load a LoRA checkpoint from disk onto the CPU.

    Paths containing ``.safetensors`` are read with ``load_safetensors``;
    anything else is handed to ``torch.load`` with ``map_location='cpu'``.

    Args:
        local_path: Filesystem path of the checkpoint file.

    Returns:
        The loaded checkpoint object (a dict of tensors for safetensors
        files; whatever ``torch.load`` returns otherwise).

    Raises:
        ValueError: If ``local_path`` is ``None``.
    """
    if local_path is None:
        # Without this guard the function reached ``return checkpoint`` with
        # the variable never assigned and failed with UnboundLocalError.
        raise ValueError("local_path must not be None")

    if '.safetensors' in local_path:
        print(f"Loading .safetensors checkpoint from {local_path}")
        return load_safetensors(local_path)

    print(f"Loading checkpoint from {local_path}")
    # NOTE: torch.load unpickles the file; only load checkpoints that come
    # from a trusted source.
    return torch.load(local_path, map_location='cpu')
def _copy_lora_pair_weights(processor, checkpoint, prefix, n_loras, groups):
    """Copy the down/up LoRA matrices found in ``checkpoint`` into ``processor``.

    For every LoRA slot ``n`` and every group in ``groups`` the checkpoint is
    probed for ``"{prefix}.{group}.{n}.down.weight"`` and the matching
    ``up.weight`` key. Missing entries are skipped, which leaves the
    processor's freshly initialised weights untouched.
    """
    for n in range(n_loras):
        for group in groups:
            for part in ("down", "up"):
                weight = checkpoint.get(f"{prefix}.{group}.{n}.{part}.weight")
                if weight is None:
                    continue
                # Resolve the target lazily so a group is only touched when
                # the checkpoint actually provides a weight for it.
                layer = getattr(getattr(processor, group)[n], part)
                layer.weight.data.copy_(weight)


def update_model_with_lora(checkpoint, lora_weights, transformer, cond_size):
    """Install LoRA attention processors built from a single checkpoint.

    Processors named ``transformer_blocks.<i>...`` with ``i`` in 0-18 get a
    ``MultiDoubleStreamBlockLoraProcessor``; processors named
    ``single_transformer_blocks.<i>...`` with ``i`` in 0-37 get a
    ``MultiSingleStreamBlockLoraProcessor``; everything else gets a plain
    ``FluxAttnProcessor2_0``. Weights are copied from ``checkpoint`` and each
    LoRA processor is moved to the module-level ``device``.

    Args:
        checkpoint: Mapping from parameter names to tensors.
        lora_weights: One entry per LoRA; its length is the LoRA count.
        transformer: Object exposing ``attn_processors`` and
            ``set_attn_processor``.
        cond_size: Passed as both ``cond_width`` and ``cond_height``.

    Raises:
        Exception: Any error is printed, the transformer is reset to default
            ``FluxAttnProcessor2_0`` processors, and the error is re-raised.
    """
    double_groups = ("q_loras", "k_loras", "v_loras", "proj_loras")
    single_groups = ("q_loras", "k_loras", "v_loras")
    try:
        number = len(lora_weights)
        # All LoRAs come from the same checkpoint, so one lookup is enough.
        rank = get_lora_rank(checkpoint) if number else None
        if rank is None:
            print("Warning: Could not determine LoRA rank from checkpoint, using default rank of 4")
            rank = 4
        ranks = [rank] * number

        lora_attn_procs = {}
        for name in transformer.attn_processors.keys():
            match = re.search(r'\.(\d+)\.', name)
            layer_index = int(match.group(1)) if match else -1  # -1: no index in the name

            if name.startswith("transformer_blocks") and layer_index in range(19):
                processor_cls, groups = MultiDoubleStreamBlockLoraProcessor, double_groups
            elif name.startswith("single_transformer_blocks") and layer_index in range(38):
                processor_cls, groups = MultiSingleStreamBlockLoraProcessor, single_groups
            else:
                lora_attn_procs[name] = FluxAttnProcessor2_0()
                continue

            # Create the processor without device/dtype; it is moved to the
            # target device only after its weights have been loaded.
            processor = processor_cls(
                dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights,
                device=None, dtype=None, cond_width=cond_size, cond_height=cond_size, n_loras=number
            )
            # Every key looked up starts with ``name``, so its block prefix
            # and layer index necessarily match this processor; the
            # checkpoint can be queried directly instead of being re-scanned
            # and filtered for each layer.
            _copy_lora_pair_weights(processor, checkpoint, name, number, groups)
            lora_attn_procs[name] = processor.to(device)

        transformer.set_attn_processor(lora_attn_procs)
    except Exception as e:
        print(f"Error in update_model_with_lora: {str(e)}")
        # Fall back to default processors so the pipeline is left in a usable
        # state, then let the caller see the original error.
        fallback_procs = {
            name: FluxAttnProcessor2_0() for name in transformer.attn_processors.keys()
        }
        transformer.set_attn_processor(fallback_procs)
        raise
def _assign_multi_lora_weights(processor, checkpoints, loras_per_checkpoint, prefix, groups):
    """Point ``processor``'s LoRA weights at tensors from several checkpoints.

    LoRA ``n`` of checkpoint ``idx`` is stored in the next free slot of the
    processor (slots are numbered consecutively across checkpoints). Keys
    have the form ``"{prefix}.{group}.{n}.down.weight"`` / ``"...up.weight"``;
    keys missing from a checkpoint are skipped instead of assigning ``None``
    to ``.data``.
    """
    slot = 0
    for idx, checkpoint in enumerate(checkpoints):
        for n in range(loras_per_checkpoint[idx]):
            for group in groups:
                for part in ("down", "up"):
                    weight = checkpoint.get(f"{prefix}.{group}.{n}.{part}.weight")
                    if weight is None:
                        continue
                    layer = getattr(getattr(processor, group)[slot], part)
                    # Rebind (not copy) the parameter data, as before.
                    layer.weight.data = weight
            slot += 1


def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size):
    """Install LoRA attention processors built from several checkpoints.

    Processors named ``transformer_blocks.<i>...`` with ``i`` in 0-18 get a
    ``MultiDoubleStreamBlockLoraProcessor``; processors named
    ``single_transformer_blocks.<i>...`` with ``i`` in 0-37 get a
    ``MultiSingleStreamBlockLoraProcessor``; everything else (including names
    without a ``.<i>.`` segment) gets a plain ``FluxAttnProcessor2_0``.

    Args:
        checkpoints: List of mappings from parameter names to tensors.
        lora_weights: One list per checkpoint; its length is the number of
            LoRAs taken from that checkpoint.
        transformer: Object exposing ``attn_processors`` and
            ``set_attn_processor``.
        cond_size: Passed as both ``cond_width`` and ``cond_height``.
    """
    cond_lora_number = [len(ls) for ls in lora_weights]
    cond_number = sum(cond_lora_number)
    # NOTE(review): ``ranks`` has one entry per checkpoint while ``n_loras``
    # counts every LoRA; confirm the processors accept that combination.
    ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints]
    multi_lora_weight = [weight for ls in lora_weights for weight in ls]

    double_groups = ("q_loras", "k_loras", "v_loras", "proj_loras")
    single_groups = ("q_loras", "k_loras", "v_loras")

    lora_attn_procs = {}
    for name in transformer.attn_processors.keys():
        match = re.search(r'\.(\d+)\.', name)
        # Use -1 when the name carries no layer index so it falls through to
        # the default processor instead of reusing a stale or unbound index.
        layer_index = int(match.group(1)) if match else -1

        if name.startswith("transformer_blocks") and layer_index in range(19):
            processor_cls, groups = MultiDoubleStreamBlockLoraProcessor, double_groups
        elif name.startswith("single_transformer_blocks") and layer_index in range(38):
            processor_cls, groups = MultiSingleStreamBlockLoraProcessor, single_groups
        else:
            lora_attn_procs[name] = FluxAttnProcessor2_0()
            continue

        processor = processor_cls(
            dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight,
            device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size,
            n_loras=cond_number
        )
        # Every key looked up starts with ``name``, so its block prefix and
        # layer index necessarily match this processor; each checkpoint can
        # be queried directly instead of being re-scanned for every layer.
        _assign_multi_lora_weights(processor, checkpoints, cond_lora_number, name, groups)
        # The assigned tensors come straight from the checkpoints, so move the
        # processor to the target device once all slots are filled.
        processor.to(device)
        lora_attn_procs[name] = processor

    transformer.set_attn_processor(lora_attn_procs)
def set_single_lora(transformer, local_path, lora_weights=None, cond_size=512):
    """Load one LoRA checkpoint and install it on ``transformer``.

    Args:
        transformer: Model whose attention processors are replaced.
        local_path: Path of the checkpoint, passed to ``load_checkpoint``.
        lora_weights: One entry per LoRA. Defaults to an empty list.
        cond_size: Passed through to ``update_model_with_lora``.
    """
    if lora_weights is None:
        # Build a fresh list per call rather than sharing one mutable default
        # object between every invocation.
        lora_weights = []
    checkpoint = load_checkpoint(local_path)
    update_model_with_lora(checkpoint, lora_weights, transformer, cond_size)
def set_multi_lora(transformer, local_paths, lora_weights=None, cond_size=512):
    """Load several LoRA checkpoints and install them on ``transformer``.

    Args:
        transformer: Model whose attention processors are replaced.
        local_paths: Iterable of checkpoint paths, each passed to
            ``load_checkpoint``.
        lora_weights: One list per checkpoint. Defaults to ``[[]]``.
        cond_size: Passed through to ``update_model_with_multi_lora``.
    """
    if lora_weights is None:
        # Build a fresh nested list per call rather than sharing one mutable
        # default object between every invocation.
        lora_weights = [[]]
    checkpoints = [load_checkpoint(local_path) for local_path in local_paths]
    update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size)
def unset_lora(transformer):
    """Replace every attention processor with a fresh ``FluxAttnProcessor2_0``.

    Args:
        transformer: Object exposing ``attn_processors`` and
            ``set_attn_processor``.
    """
    default_procs = {
        proc_name: FluxAttnProcessor2_0()
        for proc_name in transformer.attn_processors.keys()
    }
    transformer.set_attn_processor(default_procs)
'''
unset_lora(pipe.transformer)
lora_path = "./lora.safetensors"
lora_weights = [1, 1]
set_single_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights, cond_size=512)
'''