Danrisi committed
Commit d60dca8 · verified · 1 Parent(s): 7c8e751

Upload lora.py

Files changed (1):
  1. misc/comfy/weight_adapter/lora.py +217 -0
misc/comfy/weight_adapter/lora.py ADDED
@@ -0,0 +1,217 @@
+ import logging
+ from typing import Optional
+
+ import torch
+ import comfy.model_management
+ from .base import (
+     WeightAdapterBase,
+     WeightAdapterTrainBase,
+     weight_decompose,
+     pad_tensor_to_shape,
+     tucker_weight_from_conv,
+ )
+
+
+ class LoraDiff(WeightAdapterTrainBase):
+     def __init__(self, weights):
+         super().__init__()
+         mat1, mat2, alpha, mid, dora_scale, reshape = weights
+         out_dim, rank = mat1.shape[0], mat1.shape[1]
+         rank, in_dim = mat2.shape[0], mat2.shape[1]
+         if mid is not None:
+             convdim = mid.ndim - 2
+             layer = (
+                 torch.nn.Conv1d,
+                 torch.nn.Conv2d,
+                 torch.nn.Conv3d
+             )[convdim]
+         else:
+             layer = torch.nn.Linear
+         self.lora_up = layer(rank, out_dim, bias=False)
+         self.lora_down = layer(in_dim, rank, bias=False)
+         self.lora_up.weight.data.copy_(mat1)
+         self.lora_down.weight.data.copy_(mat2)
+         if mid is not None:
+             self.lora_mid = layer(mid, rank, bias=False)
+             self.lora_mid.weight.data.copy_(mid)
+         else:
+             self.lora_mid = None
+         self.rank = rank
+         self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+     def __call__(self, w):
+         org_dtype = w.dtype
+         if self.lora_mid is None:
+             diff = self.lora_up.weight @ self.lora_down.weight
+         else:
+             diff = tucker_weight_from_conv(
+                 self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
+             )
+         scale = self.alpha / self.rank
+         weight = w + scale * diff.reshape(w.shape)
+         return weight.to(org_dtype)
+
+     def passive_memory_usage(self):
+         return sum(param.numel() * param.element_size() for param in self.parameters())
+
+
+ class LoRAAdapter(WeightAdapterBase):
+     name = "lora"
+
+     def __init__(self, loaded_keys, weights):
+         self.loaded_keys = loaded_keys
+         self.weights = weights
+
+     @classmethod
+     def create_train(cls, weight, rank=1, alpha=1.0):
+         out_dim = weight.shape[0]
+         in_dim = weight.shape[1:].numel()
+         mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
+         mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
+         torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
+         torch.nn.init.constant_(mat2, 0.0)
+         return LoraDiff(
+             (mat1, mat2, alpha, None, None, None)
+         )
+
+     def to_train(self):
+         return LoraDiff(self.weights)
+
+     @classmethod
+     def load(
+         cls,
+         x: str,
+         lora: dict[str, torch.Tensor],
+         alpha: float,
+         dora_scale: torch.Tensor,
+         loaded_keys: set[str] = None,
+     ) -> Optional["LoRAAdapter"]:
+         if loaded_keys is None:
+             loaded_keys = set()
+
+         reshape_name = "{}.reshape_weight".format(x)
+         regular_lora = "{}.lora_up.weight".format(x)
+         diffusers_lora = "{}_lora.up.weight".format(x)
+         diffusers2_lora = "{}.lora_B.weight".format(x)
+         diffusers3_lora = "{}.lora.up.weight".format(x)
+         mochi_lora = "{}.lora_B".format(x)
+         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
+         qwen_default_lora = "{}.lora_B.default.weight".format(x)
+         chroma_radiance_lora = "{}.lora.lora_B".format(x)
+         A_name = None
+
+         if regular_lora in lora.keys():
+             A_name = regular_lora
+             B_name = "{}.lora_down.weight".format(x)
+             mid_name = "{}.lora_mid.weight".format(x)
+         elif diffusers_lora in lora.keys():
+             A_name = diffusers_lora
+             B_name = "{}_lora.down.weight".format(x)
+             mid_name = None
+         elif diffusers2_lora in lora.keys():
+             A_name = diffusers2_lora
+             B_name = "{}.lora_A.weight".format(x)
+             mid_name = None
+         elif diffusers3_lora in lora.keys():
+             A_name = diffusers3_lora
+             B_name = "{}.lora.down.weight".format(x)
+             mid_name = None
+         elif mochi_lora in lora.keys():
+             A_name = mochi_lora
+             B_name = "{}.lora_A".format(x)
+             mid_name = None
+         elif transformers_lora in lora.keys():
+             A_name = transformers_lora
+             B_name = "{}.lora_linear_layer.down.weight".format(x)
+             mid_name = None
+         elif qwen_default_lora in lora.keys():
+             A_name = qwen_default_lora
+             B_name = "{}.lora_A.default.weight".format(x)
+             mid_name = None
+         elif chroma_radiance_lora in lora.keys():
+             A_name = chroma_radiance_lora
+             B_name = "{}.lora.lora_A".format(x)
+             mid_name = None
+
+         if A_name is not None:
+             mid = None
+             if mid_name is not None and mid_name in lora.keys():
+                 mid = lora[mid_name]
+                 loaded_keys.add(mid_name)
+             reshape = None
+             if reshape_name in lora.keys():
+                 try:
+                     reshape = lora[reshape_name].tolist()
+                     loaded_keys.add(reshape_name)
+                 except:
+                     pass
+             weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
+             loaded_keys.add(A_name)
+             loaded_keys.add(B_name)
+             return cls(loaded_keys, weights)
+         else:
+             return None
+
+     def calculate_weight(
+         self,
+         weight,
+         key,
+         strength,
+         strength_model,
+         offset,
+         function,
+         intermediate_dtype=torch.float32,
+         original_weight=None,
+     ):
+         v = self.weights
+         mat1 = comfy.model_management.cast_to_device(
+             v[0], weight.device, intermediate_dtype
+         )
+         mat2 = comfy.model_management.cast_to_device(
+             v[1], weight.device, intermediate_dtype
+         )
+         dora_scale = v[4]
+         reshape = v[5]
+
+         if reshape is not None:
+             weight = pad_tensor_to_shape(weight, reshape)
+
+         if v[2] is not None:
+             alpha = v[2] / mat2.shape[0]
+         else:
+             alpha = 1.0
+
+         if v[3] is not None:
+             # locon mid weights, hopefully the math is fine because I didn't properly test it
+             mat3 = comfy.model_management.cast_to_device(
+                 v[3], weight.device, intermediate_dtype
+             )
+             final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+             mat2 = (
+                 torch.mm(
+                     mat2.transpose(0, 1).flatten(start_dim=1),
+                     mat3.transpose(0, 1).flatten(start_dim=1),
+                 )
+                 .reshape(final_shape)
+                 .transpose(0, 1)
+             )
+         try:
+             lora_diff = torch.mm(
+                 mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
+             ).reshape(weight.shape)
+             del mat1, mat2
+             if dora_scale is not None:
+                 weight = weight_decompose(
+                     dora_scale,
+                     weight,
+                     lora_diff,
+                     alpha,
+                     strength,
+                     intermediate_dtype,
+                     function,
+                 )
+             else:
+                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+         except Exception as e:
+             logging.error("ERROR {} {} {}".format(self.name, key, e))
+         return weight
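
For reference, the merge performed by calculate_weight above is the standard LoRA update W' = W + strength * (alpha / rank) * (up @ down). A minimal standalone sketch of that computation in plain PyTorch follows; it is illustrative only, and the shapes and tensor names are assumptions rather than part of the uploaded file.

import torch

# Assumed example shapes: a 512x512 linear weight with a rank-8 LoRA.
out_dim, in_dim, rank = 512, 512, 8
weight = torch.randn(out_dim, in_dim)         # original model weight W
lora_up = torch.randn(out_dim, rank) * 0.01   # up / "A" matrix (mat1 above)
lora_down = torch.randn(rank, in_dim) * 0.01  # down / "B" matrix (mat2 above)
alpha, strength = 8.0, 1.0

# Same scaling convention as calculate_weight: alpha is divided by the rank (mat2.shape[0]).
scale = alpha / lora_down.shape[0]
lora_diff = lora_up @ lora_down               # low-rank delta, shape (out_dim, in_dim)
merged = weight + strength * scale * lora_diff.to(weight.dtype)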