| from typing import List, Dict |
| import re |
| import torch |
|
|
| from modules import extra_networks, shared |
|
|
# Splits a prompt on the standalone, upper-case word AND. The \b anchors keep
# it from matching inside longer words such as "BRAND" or "ANDROID".
re_AND = re.compile(r"\bAND\b")
|
|
|
|
def load_prompt_loras(prompt: str):
    """Rebuild the module-level ``prompt_loras`` table from *prompt*.

    The prompt is split on the standalone word ``AND`` (see ``re_AND``).
    Each piece is handed to ``extra_networks.parse_prompt`` and the entries
    found under its ``'lora'`` key are turned into one ``{name: multiplier}``
    dict, where the name is the first item of an entry and the multiplier is
    the second item converted to ``float`` (``1.0`` when there is no second
    item).  If a name appears more than once in a piece, the last occurrence
    wins.

    The resulting list of dicts is repeated ``num_batches`` times and written
    into ``prompt_loras`` in place, so every holder of that list object sees
    the new content.

    Args:
        prompt: Full prompt text, possibly containing several ``AND``-joined
            sub-prompts.
    """
    # Cleared up front, before any parsing: if parsing raises, the table is
    # left empty rather than holding data from a previous prompt.
    prompt_loras.clear()

    per_subprompt = []
    for subprompt in re_AND.split(prompt):
        _, extra_network_data = extra_networks.parse_prompt(subprompt)
        per_subprompt.append({
            params.items[0]: float(params.items[1]) if len(params.items) > 1 else 1.0
            for params in extra_network_data['lora']
        })

    # List repetition reuses the same dict objects for every batch copy.
    prompt_loras.extend(per_subprompt * num_batches)
|
|
|
|
def reset_counters():
    """Put both module-level layer counters back to their starting values.

    ``text_model_encoder_counter`` is set to ``-1``, the value
    ``lora_forward`` tests for before it seeds that counter, and
    ``diffusion_model_counter`` is set to ``0``.
    """
    global text_model_encoder_counter, diffusion_model_counter

    text_model_encoder_counter, diffusion_model_counter = -1, 0
|
|
|
|
def lora_forward(compvis_module, input, res):
    """Add the output of every loaded LoRA module to a layer's result.

    For each entry of ``lora.loaded_loras`` that has a module registered
    under the layer's ``lora_layer_name``, a patch ``up(down(x))`` is
    computed and added to ``res``, scaled by ``alpha`` and a multiplier.
    ``x`` is ``res`` when ``shared.opts.lora_apply_to_outputs`` is set and
    ``res`` and ``input`` have the same shape, otherwise ``input``.

    The multiplier depends on the module-level ``enabled`` flag and on the
    layer name:

    * ``enabled`` is false, or the name starts with neither
      ``"transformer_"`` nor ``"diffusion_model_"``: the LoRA's own
      ``multiplier`` is applied to all of ``res``.
    * ``"transformer_"``: the multiplier comes from the ``prompt_loras``
      entry selected by ``text_model_encoder_counter // num_loras``.  When
      that index is outside the table, the LoRA's own multiplier is used,
      and only if ``opt_uc_text_model_encoder`` is set.  The counter
      advances each time a layer whose name ends in ``"_11_mlp_fc2"`` is
      processed.
    * ``"diffusion_model_"``: multipliers are applied per index along the
      first axis of ``res``.  Either the whole ``res`` is walked in one go
      (when its first dimension equals
      ``num_batches * len(prompt_loras) + num_batches``), or
      ``diffusion_model_counter`` selects which slice of ``prompt_loras``
      the current call corresponds to.  That counter advances on layers
      whose name ends in ``"_11_1_proj_out"``.

    Both counters are module-level state and are modified here.

    Args:
        compvis_module: The layer being evaluated.  Its optional
            ``lora_layer_name`` attribute is the key into each LoRA's
            ``modules`` mapping.
        input: The value the layer was called with.
        res: The layer's own output; updated with ``+=`` and returned.

    Returns:
        ``res`` with all applicable patches added.  It is returned untouched
        when no LoRA is loaded or the layer has no ``lora_layer_name``.
    """
    global text_model_encoder_counter
    global diffusion_model_counter

    # Imported at call time.  The alias keeps the module name distinct from
    # the per-LoRA loop variable below (the two used to share one name).
    import lora as lora_module

    loaded_loras = lora_module.loaded_loras
    num_loras = len(loaded_loras)
    if num_loras == 0:
        return res

    lora_layer_name: str | None = getattr(compvis_module, 'lora_layer_name', None)
    if lora_layer_name is None:
        return res

    num_prompts = len(prompt_loras)

    # -1 is what reset_counters() leaves behind.  Seeding with
    # num_prompts * num_loras makes the first index computed below equal to
    # num_prompts, i.e. just outside the prompt_loras table.
    if text_model_encoder_counter == -1:
        text_model_encoder_counter = num_prompts * num_loras

    # Everything below depends only on the layer and on module state that
    # does not change during this call, so it is computed once.
    use_text_rules = enabled and lora_layer_name.startswith("transformer_")
    use_diffusion_rules = (
        enabled
        and not use_text_rules
        and lora_layer_name.startswith("diffusion_model_")
    )
    is_last_text_layer = lora_layer_name.endswith("_11_mlp_fc2")
    is_last_diffusion_layer = lora_layer_name.endswith("_11_1_proj_out")
    counter_wrap = (num_prompts + num_batches) * num_loras
    combined_rows = num_batches * num_prompts + num_batches

    for entry in loaded_loras:
        module = entry.modules.get(lora_layer_name, None)
        if module is None:
            continue

        if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
            patch = module.up(module.down(res))
        else:
            patch = module.up(module.down(input))

        alpha = module.alpha / module.up.weight.shape[1] if module.alpha else 1.0

        if use_text_rules:
            # The counter is bumped once per LoRA on the last layer, hence
            # the division by num_loras to get back to a table index.
            prompt_index = text_model_encoder_counter // num_loras
            if 0 <= prompt_index < num_prompts:
                multiplier = prompt_loras[prompt_index].get(entry.name, 0.0)
                if multiplier != 0.0:
                    res += multiplier * alpha * patch
            elif opt_uc_text_model_encoder and entry.multiplier != 0.0:
                # Index outside the table: presumably the unconditional
                # pass, going by the opt_uc_* flag name (TODO confirm).
                res += entry.multiplier * alpha * patch

            if is_last_text_layer:
                text_model_encoder_counter += 1
                if text_model_encoder_counter == counter_wrap:
                    text_model_encoder_counter = 0

        elif use_diffusion_rules:
            if res.shape[0] == combined_rows:
                # NOTE(review): load_prompt_loras already repeats the table
                # num_batches times, so num_prompts may itself include the
                # batch factor; confirm this row count is the intended one.
                tensor_off = 0
                uncond_off = num_batches * num_prompts
                for _ in range(num_batches):
                    for loras in prompt_loras:
                        multiplier = loras.get(entry.name, 0.0)
                        if multiplier != 0.0:
                            res[tensor_off] += multiplier * alpha * patch[tensor_off]
                        tensor_off += 1

                    if opt_uc_diffusion_model and entry.multiplier != 0.0:
                        res[uncond_off] += entry.multiplier * alpha * patch[uncond_off]
                    uncond_off += 1
            else:
                cur_num_prompts = res.shape[0]
                base = (diffusion_model_counter // cur_num_prompts) // num_loras * cur_num_prompts
                if 0 <= base < num_prompts:
                    # NOTE(review): only `base` is range-checked; `base + off`
                    # is assumed to stay inside prompt_loras; verify.
                    for off in range(cur_num_prompts):
                        multiplier = prompt_loras[base + off].get(entry.name, 0.0)
                        if multiplier != 0.0:
                            res[off] += multiplier * alpha * patch[off]
                elif opt_uc_diffusion_model and entry.multiplier != 0.0:
                    res += entry.multiplier * alpha * patch

                if is_last_diffusion_layer:
                    diffusion_model_counter += cur_num_prompts
                    if diffusion_model_counter >= counter_wrap:
                        diffusion_model_counter = 0

        elif entry.multiplier != 0.0:
            # Per-prompt handling is off, or this layer belongs to neither
            # recognised prefix: apply the LoRA's own multiplier everywhere.
            res += entry.multiplier * alpha * patch

    return res
|
|
|
|
def lora_Linear_forward(self, input):
    """Run the saved ``Linear`` forward, then let ``lora_forward`` patch it.

    ``torch.nn.Linear_forward_before_lora`` is looked up on ``torch.nn`` at
    call time; it is not defined in this section, so it is presumably
    attached by whatever code swaps in this function (verify at the patch
    site).

    Args:
        input: Value passed through to the saved forward and to
            ``lora_forward``.

    Returns:
        Whatever ``lora_forward`` returns for this layer.
    """
    base_output = torch.nn.Linear_forward_before_lora(self, input)
    return lora_forward(self, input, base_output)
|
|
|
|
def lora_Conv2d_forward(self, input):
    """Run the saved ``Conv2d`` forward, then let ``lora_forward`` patch it.

    ``torch.nn.Conv2d_forward_before_lora`` is looked up on ``torch.nn`` at
    call time; it is not defined in this section, so it is presumably
    attached by whatever code swaps in this function (verify at the patch
    site).

    Args:
        input: Value passed through to the saved forward and to
            ``lora_forward``.

    Returns:
        Whatever ``lora_forward`` returns for this layer.
    """
    base_output = torch.nn.Conv2d_forward_before_lora(self, input)
    return lora_forward(self, input, base_output)
|
|
|
|
# --- Switches read by lora_forward -----------------------------------------

# When false, every LoRA is applied with its own multiplier and the
# per-prompt table is ignored.
enabled: bool = False
# When true, the LoRA's own multiplier is also applied on "transformer_"
# layers while the text-encoder counter points outside prompt_loras.
opt_uc_text_model_encoder: bool = False
# Same, for "diffusion_model_" layers.
opt_uc_diffusion_model: bool = False
# Not read anywhere in this section of the file.
verbose: bool = True

# --- State shared between load_prompt_loras, reset_counters, lora_forward ---

# Repetition factor for the prompt table; also part of the counter wrap-around
# threshold in lora_forward.
num_batches: int = 0
# One {lora name: multiplier} dict per sub-prompt, repeated num_batches times.
# Mutated in place by load_prompt_loras.
prompt_loras: List[Dict[str, float]] = []
# -1 means "not seeded yet"; lora_forward replaces it on first use.
text_model_encoder_counter: int = -1
diffusion_model_counter: int = 0
|
|