Spaces:
Runtime error
Runtime error
| import torch | |
| from contextlib import contextmanager | |
| from typing import Union, Tuple | |
# Type alias for "an int or an (int, int) pair", the form Conv2dWithLoRA accepts
# for kernel_size / stride / dilation (same definition as torch's private
# torch.nn.common_types._size_2_t).
_size_2_t = Union[int, Tuple[int, int]]
class LinearWithLoRA(torch.nn.Module):
    """Stand-in for ``torch.nn.Linear`` installed by ``controlnet_lora_hijack``.

    The constructor allocates no tensors: it only records its arguments.
    ``bias`` is accepted for signature compatibility, but it is ignored.
    ``device`` and ``dtype`` are stored but never read by this class.
    The weight actually used is resolved lazily in :meth:`forward`:

    1. ``self.weight``, if something has assigned it (e.g. a state-dict load);
    2. otherwise the ``.weight`` of the module attached via :meth:`bind_lora`,
       plus ``torch.mm(up, down)`` when both ``up`` and ``down`` are set;
    3. otherwise ``forward`` returns ``None``.
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = True,
            device=None,
            dtype=None) -> None:
        super().__init__()
        # Shape / factory metadata only; nothing here is a registered parameter.
        self.in_features = in_features
        self.out_features = out_features
        self.device = device
        self.dtype = dtype
        # Filled in later from outside this class (all start out empty).
        self.weight = None
        self.bias = None
        self.up = None    # low-rank factor; presumably (out_features, rank) — confirm
        self.down = None  # low-rank factor; presumably (rank, in_features) — confirm
        self.weight_module = None

    def bind_lora(self, weight_module):
        """Use ``weight_module.weight`` as the base weight for subsequent calls."""
        # Boxed in a list so nn.Module.__setattr__ stores a plain attribute
        # instead of registering ``weight_module`` as a child module.
        self.weight_module = [weight_module]

    def unbind_lora(self):
        """Drop the bound base module, but only if both LoRA factors are present."""
        if self.up is None or self.down is None:
            return
        # SAI's model is weird and needs this
        self.weight_module = None

    def get_original_weight(self):
        """Return the bound module's ``.weight``, or ``None`` when nothing is bound."""
        return None if self.weight_module is None else self.weight_module[0].weight

    def forward(self, x):
        """Apply ``torch.nn.functional.linear`` with the resolved weight.

        Weight and bias are cast with ``.to(x)`` (matching x's dtype/device).
        Returns ``None`` if there is neither an assigned ``weight`` nor a bound module.
        """
        if self.weight is not None:
            effective = self.weight.to(x)
        else:
            base = self.get_original_weight()
            if base is None:
                return None  # A1111 needs first_time_calculation
            effective = base.to(x)
            if self.up is not None and self.down is not None:
                effective = effective + torch.mm(self.up, self.down).to(x)
        bias = self.bias.to(x) if self.bias is not None else None
        return torch.nn.functional.linear(x, effective, bias)
class Conv2dWithLoRA(torch.nn.Module):
    """Stand-in for ``torch.nn.Conv2d`` installed by ``controlnet_lora_hijack``.

    The constructor allocates no tensors: it only records its arguments.
    ``bias`` is accepted for signature compatibility, but it is ignored.
    The kernel actually used is resolved lazily in :meth:`forward`:

    1. ``self.weight``, if something has assigned it (e.g. a state-dict load);
    2. otherwise the ``.weight`` of the module attached via :meth:`bind_lora`,
       plus ``up.flatten(1) @ down.flatten(1)`` reshaped to that weight's shape
       when both ``up`` and ``down`` are set;
    3. otherwise ``forward`` returns ``None``.

    NOTE: ``padding_mode`` is stored but never used by ``forward``, which always
    calls ``conv2d`` with ``self.padding`` (i.e. zero padding).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super().__init__()
        # Convolution hyper-parameters consumed by forward().
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        # Metadata kept for interface parity with torch.nn.Conv2d.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding_mode = padding_mode
        self.device = device
        self.dtype = dtype
        # Filled in later from outside this class (all start out empty).
        self.weight = None
        self.bias = None
        self.up = None
        self.down = None
        self.weight_module = None

    def bind_lora(self, weight_module):
        """Use ``weight_module.weight`` as the base kernel for subsequent calls."""
        # Boxed in a list so nn.Module.__setattr__ stores a plain attribute
        # instead of registering ``weight_module`` as a child module.
        self.weight_module = [weight_module]

    def unbind_lora(self):
        """Drop the bound base module, but only if both LoRA factors are present."""
        if self.up is None or self.down is None:
            return
        # SAI's model is weird and needs this
        self.weight_module = None

    def get_original_weight(self):
        """Return the bound module's ``.weight``, or ``None`` when nothing is bound."""
        return None if self.weight_module is None else self.weight_module[0].weight

    def forward(self, x):
        """Apply ``torch.nn.functional.conv2d`` with the resolved kernel.

        Kernel and bias are cast with ``.to(x)``.  Returns ``None`` if there is
        neither an assigned ``weight`` nor a bound module.
        """
        if self.weight is not None:
            kernel = self.weight.to(x)
        else:
            base = self.get_original_weight()
            if base is None:
                return None  # A1111 needs first_time_calculation
            kernel = base.to(x)
            if self.up is not None and self.down is not None:
                # Multiply the factors as 2-D matrices, then fold back to the kernel shape.
                delta = torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))
                kernel = kernel + delta.reshape(base.shape).to(x)
        bias = self.bias.to(x) if self.bias is not None else None
        return torch.nn.functional.conv2d(
            x, kernel, bias, self.stride, self.padding, self.dilation, self.groups
        )
# @contextmanager is required: without it, calling this function returns a bare
# generator, which has no __enter__/__exit__ and cannot be used in a `with`.
@contextmanager
def controlnet_lora_hijack():
    """Temporarily swap ``torch.nn.Linear`` / ``torch.nn.Conv2d`` for LoRA-aware stand-ins.

    While the ``with`` body runs, ``torch.nn.Linear(...)`` constructs a
    ``LinearWithLoRA`` and ``torch.nn.Conv2d(...)`` constructs a
    ``Conv2dWithLoRA``. The patch is applied to the ``torch.nn`` module object
    itself, so it is visible process-wide while active. The original classes
    are restored on exit, including when the body raises.

    Usage::

        with controlnet_lora_hijack():
            model = build_model()
    """
    linear, conv2d = torch.nn.Linear, torch.nn.Conv2d
    torch.nn.Linear, torch.nn.Conv2d = LinearWithLoRA, Conv2dWithLoRA
    try:
        yield
    finally:
        torch.nn.Linear, torch.nn.Conv2d = linear, conv2d
def recursive_set(obj, key, value):
    """Set the attribute named by dotted path ``key`` on ``obj`` to ``value``.

    Every segment except the last is looked up with ``getattr(..., None)``.
    If ``obj`` or any intermediate object is ``None``, or any intermediate
    attribute is missing, the call does nothing.
    """
    *parents, leaf = key.split('.')
    node = obj
    for segment in parents:
        if node is None:
            return
        node = getattr(node, segment, None)
    if node is None:
        return
    setattr(node, leaf, value)
def force_load_state_dict(model, state_dict):
    """Assign every tensor of ``state_dict`` onto ``model`` as a ``torch.nn.Parameter``.

    Keys are dotted attribute paths handled by ``recursive_set``, so keys whose
    parent objects don't exist are skipped silently. The dict is consumed: each
    entry is deleted right after it has been assigned (presumably to drop
    references early; confirm), so it is empty after a normal return.
    Returns ``None``.
    """
    for name in tuple(state_dict):
        recursive_set(model, name, torch.nn.Parameter(state_dict[name]))
        state_dict.pop(name)
def recursive_bind_lora(obj, key, value):
    """Call ``bind_lora(value)`` on the object at dotted path ``key`` under ``obj``.

    Does nothing if ``obj``, any intermediate attribute, or the target is
    missing/``None``, or if the target has no ``bind_lora`` attribute.
    """
    *parents, leaf = key.split('.')
    node = obj
    for segment in parents:
        if node is None:
            return
        node = getattr(node, segment, None)
    if node is None:
        return
    target = getattr(node, leaf, None)
    if target is None or not hasattr(target, 'bind_lora'):
        return
    target.bind_lora(value)
def recursive_get(obj, key):
    """Return the attribute at dotted path ``key`` under ``obj``, or ``None``.

    ``None`` is returned when ``obj`` is ``None`` or any segment is missing.
    """
    node = obj
    for segment in key.split('.'):
        if node is None:
            return None
        node = getattr(node, segment, None)
    return node
def bind_control_lora(base_model, control_lora_model):
    """Bind each module of ``base_model`` to the same-named module of ``control_lora_model``.

    Module paths are derived from ``base_model.state_dict()`` keys by dropping
    the final segment (e.g. ``"a.b.weight"`` -> ``"a.b"``). Each path is first
    resolved on ``base_model``; the result is then passed to
    ``recursive_bind_lora`` for the same path on ``control_lora_model``.
    """
    module_paths = {name.rsplit('.', 1)[0] for name in base_model.state_dict()}
    # Resolve everything on the base model first, then bind.
    resolved = {path: recursive_get(base_model, path) for path in module_paths}
    for path, module in resolved.items():
        recursive_bind_lora(control_lora_model, path, module)
def torch_dfs(model: torch.nn.Module):
    """Return ``model`` and all its descendants in depth-first pre-order.

    Children are visited in ``model.children()`` order. A module reachable
    through several parents appears once per occurrence.
    """
    visited = []
    stack = [model]
    while stack:
        current = stack.pop()
        visited.append(current)
        # Push reversed so the first child is popped (visited) first.
        stack.extend(reversed(list(current.children())))
    return visited
def unbind_control_lora(control_lora_model):
    """Call ``unbind_lora()`` on every module of the tree that defines it.

    The tree is walked with ``torch_dfs`` (the root is included). Returns ``None``.
    """
    hooked = [module for module in torch_dfs(control_lora_model) if hasattr(module, 'unbind_lora')]
    for module in hooked:
        module.unbind_lora()