lmganon123 committed on
Commit
338a28b
·
verified ·
1 Parent(s): 824dd20

Upload 3 files

Browse files
Files changed (3) hide show
  1. hypernetwork.py +787 -0
  2. pochi_4-l-500721.pt +3 -0
  3. pochi_pt3.pt +3 -0
hypernetwork.py ADDED
@@ -0,0 +1,787 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import glob
3
+ import html
4
+ import os
5
+ import inspect
6
+ from contextlib import closing
7
+
8
+ import modules.textual_inversion.dataset
9
+ import torch
10
+ import tqdm
11
+ from einops import rearrange, repeat
12
+ from ldm.util import default
13
+ from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
14
+ from modules.textual_inversion import textual_inversion, logging
15
+ from modules.textual_inversion.learn_schedule import LearnRateScheduler
16
+ from torch import einsum
17
+ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
18
+
19
+ from collections import deque
20
+ from statistics import stdev, mean
21
+
22
+
23
# Name -> class lookup for every class exposed by torch.optim, leaving out the
# "Optimizer" base class itself.
optimizer_dict = dict(
    member
    for member in inspect.getmembers(torch.optim, inspect.isclass)
    if member[0] != "Optimizer"
)
24
+
25
class HypernetworkModule(torch.nn.Module):
    """Residual stack of layers applied to an attention context tensor.

    ``forward`` returns ``x + self.linear(x) * multiplier``; the multiplier is
    only applied in eval mode (training always uses a factor of 1).
    """

    # Activation name -> module class. The hand-written aliases come first and
    # are then extended with every class defined in torch.nn.modules.activation
    # under its lower-cased class name.
    activation_dict = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }
    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})

    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
                 add_layer_norm=False, activate_output=False, dropout_structure=None):
        """Build the layer stack and either load or initialise its weights.

        Args:
            dim: base width; layer ``i`` maps ``dim * layer_structure[i]`` to
                ``dim * layer_structure[i + 1]`` features.
            state_dict: weights to load; when ``None`` the layers are freshly
                initialised according to ``weight_init``.
            layer_structure: width multipliers; must start and end with 1.
            activation_func: key of ``activation_dict``, ``"linear"`` or ``None``.
            weight_init: ``"Normal"``, ``"XavierUniform"``, ``"XavierNormal"``,
                ``"KaimingUniform"`` or ``"KaimingNormal"``.
            add_layer_norm: insert a parameter-free LayerNorm after the early layers.
            activate_output: when false, the last layer skips the activation branch.
            dropout_structure: per-position dropout probabilities; entry
                ``i + 1`` applies after layer ``i``.

        Raises:
            AssertionError: on an invalid ``layer_structure`` or dropout value.
            RuntimeError: on an unknown ``activation_func``.
            KeyError: on an unknown ``weight_init``.
        """
        super().__init__()

        self.multiplier = 1.0

        assert layer_structure is not None, "layer_structure must not be None"
        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"

        linears = []
        for i in range(len(layer_structure) - 1):

            # Add a fully-connected layer
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))

            # Layer normalization (without learnable affine parameters) is only
            # inserted for layers with index below len(layer_structure) - 3.
            if add_layer_norm:
                if (i < len(layer_structure) - 3):
                    linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]), eps=1e-05, elementwise_affine=False))

            # Everything should be now parsed into dropout structure, and applied here.
            # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
            if dropout_structure is not None and dropout_structure[i+1] > 0:
                assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
                linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))

            # Add an activation func except last layer
            if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
                pass
            elif activation_func in self.activation_dict:
                # Same index limit as the layer norm above: later layers get no activation.
                if (i < len(layer_structure) - 3):
                    linears.append(self.activation_dict[activation_func]())
            else:
                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')

            # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].

        self.linear = torch.nn.Sequential(*linears)

        if state_dict is not None:
            self.fix_old_state_dict(state_dict)
            self.load_state_dict(state_dict)
        else:
            for layer in self.linear:
                if not isinstance(layer, (torch.nn.Linear, torch.nn.LayerNorm)):
                    continue
                # A LayerNorm built with elementwise_affine=False has weight and
                # bias set to None, so there is nothing to initialise.
                if layer.weight is None or layer.bias is None:
                    continue
                w, b = layer.weight.data, layer.bias.data
                if weight_init == "Normal" or isinstance(layer, torch.nn.LayerNorm):
                    normal_(w, mean=0.0, std=0.01)
                    normal_(b, mean=0.0, std=0)
                elif weight_init == 'XavierUniform':
                    xavier_uniform_(w)
                    zeros_(b)
                elif weight_init == 'XavierNormal':
                    xavier_normal_(w)
                    zeros_(b)
                elif weight_init == 'KaimingUniform':
                    kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
                    zeros_(b)
                elif weight_init == 'KaimingNormal':
                    kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
                    zeros_(b)
                else:
                    raise KeyError(f"Key {weight_init} is not defined as initialization!")
        self.to(devices.device)

    def fix_old_state_dict(self, state_dict):
        """Rename the ``linear1.*`` / ``linear2.*`` keys to ``linear.0.*`` / ``linear.1.*`` in place."""
        changes = {
            'linear1.bias': 'linear.0.bias',
            'linear1.weight': 'linear.0.weight',
            'linear2.bias': 'linear.1.bias',
            'linear2.weight': 'linear.1.weight',
        }

        for fr, to in changes.items():
            x = state_dict.get(fr, None)
            if x is None:
                continue

            del state_dict[fr]
            state_dict[to] = x

    def forward(self, x):
        """Return ``x`` plus the layer stack's output, scaled by ``multiplier`` outside training."""
        return x + self.linear(x) * (self.multiplier if not self.training else 1)

    def trainables(self):
        """Return the weight and bias parameters of the Linear/LayerNorm layers.

        Layers without parameters (LayerNorm with ``elementwise_affine=False``)
        contribute nothing instead of ``None`` entries.
        """
        layer_structure = []
        for layer in self.linear:
            if isinstance(layer, (torch.nn.Linear, torch.nn.LayerNorm)):
                layer_structure += [p for p in (layer.weight, layer.bias) if p is not None]
        return layer_structure
130
+
131
+
132
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
    """Derive a per-position dropout list from the legacy boolean options.

    Args:
        layer_structure: sequence used only for its length; ``None`` is
            treated as ``[1, 2, 1]``.
        use_dropout: when false, every position gets 0.
        last_layer_dropout: compatibility switch; when true the second-to-last
            position also gets 0.3.

    Returns:
        A list with one entry per element of ``layer_structure``. The first
        and last entries are always 0; the entries in between are 0.3, except
        the second-to-last, which is 0.3 only if ``last_layer_dropout``.
    """
    if layer_structure is None:
        layer_structure = [1, 2, 1]
    size = len(layer_structure)
    # With fewer than three entries there is no interior position, so the
    # result is all zeros; the formula below would return a list longer than
    # the layer structure.
    if not use_dropout or size < 3:
        return [0] * size
    dropout_values = [0]
    dropout_values.extend([0.3] * (size - 3))
    dropout_values.append(0.3 if last_layer_dropout else 0)
    dropout_values.append(0)
    return dropout_values
146
+
147
+
148
class Hypernetwork:
    """A named set of HypernetworkModule pairs keyed by context width, plus its metadata.

    ``self.layers[size]`` is a 2-tuple of modules; apply_single_hypernetwork
    uses element 0 for the key context and element 1 for the value context.
    """

    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
        """Create a hypernetwork with a fresh module pair for each entry of ``enable_sizes``.

        Recognised ``kwargs``: ``last_layer_dropout`` (default ``True``) and
        ``dropout_structure`` (default ``None``, in which case it is derived
        with parse_dropout_structure).
        """
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.layer_structure = layer_structure
        self.activation_func = activation_func
        self.weight_init = weight_init
        self.add_layer_norm = add_layer_norm
        self.use_dropout = use_dropout
        self.activate_output = activate_output
        self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
        self.dropout_structure = kwargs.get('dropout_structure', None)
        if self.dropout_structure is None:
            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
        self.optimizer_name = None
        self.optimizer_state_dict = None
        self.optional_info = None

        for size in enable_sizes or []:
            self.layers[size] = (
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
            )
        self.eval()

    def weights(self):
        """Return a flat list of every parameter of every module."""
        res = []
        for layers in self.layers.values():
            for layer in layers:
                res += layer.parameters()
        return res

    def train(self, mode=True):
        """Set every module's training mode and set ``requires_grad`` of its parameters to ``mode``."""
        for layers in self.layers.values():
            for layer in layers:
                layer.train(mode=mode)
                for param in layer.parameters():
                    param.requires_grad = mode

    def to(self, device):
        """Move every module to ``device`` and return ``self``."""
        for layers in self.layers.values():
            for layer in layers:
                layer.to(device)

        return self

    def set_multiplier(self, multiplier):
        """Set ``multiplier`` on every module and return ``self``."""
        for layers in self.layers.values():
            for layer in layers:
                layer.multiplier = multiplier

        return self

    def eval(self):
        """Switch every module to eval mode and turn off gradients for its parameters."""
        # Module.eval() is Module.train(False), so this is train() with mode=False.
        self.train(False)

    def save(self, filename):
        """Write weights and metadata to ``filename``.

        When ``shared.opts.save_optimizer_state`` is set and an optimizer state
        is attached, it is written to ``filename + '.optim'`` together with
        this hypernetwork's short hash.
        """
        state_dict = {}
        optimizer_saved_dict = {}

        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())

        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['layer_structure'] = self.layer_structure
        state_dict['activation_func'] = self.activation_func
        state_dict['is_layer_norm'] = self.add_layer_norm
        state_dict['weight_initialization'] = self.weight_init
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
        state_dict['activate_output'] = self.activate_output
        state_dict['use_dropout'] = self.use_dropout
        state_dict['dropout_structure'] = self.dropout_structure
        state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
        state_dict['optional_info'] = self.optional_info if self.optional_info else None

        if self.optimizer_name is not None:
            optimizer_saved_dict['optimizer_name'] = self.optimizer_name

        torch.save(state_dict, filename)
        if shared.opts.save_optimizer_state and self.optimizer_state_dict:
            optimizer_saved_dict['hash'] = self.shorthash()
            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
            torch.save(optimizer_saved_dict, filename + '.optim')

    def load(self, filename):
        """Load weights and metadata from ``filename``, then switch to eval mode.

        A ``filename + '.optim'`` file, if present, supplies the optimizer
        state, but only when its stored hash equals ``self.shorthash()``.
        """
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]

        # NOTE(review): torch.load unpickles the file; only load hypernetwork
        # files from sources you trust.
        state_dict = torch.load(filename, map_location='cpu')

        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
        self.optional_info = state_dict.get('optional_info', None)
        self.activation_func = state_dict.get('activation_func', None)
        self.weight_init = state_dict.get('weight_initialization', 'Normal')
        self.add_layer_norm = state_dict.get('is_layer_norm', False)
        self.dropout_structure = state_dict.get('dropout_structure', None)
        self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
        self.activate_output = state_dict.get('activate_output', True)
        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
        # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
        if self.dropout_structure is None:
            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)

        if shared.opts.print_hypernet_extra:
            if self.optional_info is not None:
                print(f" INFO:\n {self.optional_info}\n")

            print(f" Layer structure: {self.layer_structure}")
            print(f" Activation function: {self.activation_func}")
            print(f" Weight initialization: {self.weight_init}")
            print(f" Layer norm: {self.add_layer_norm}")
            print(f" Dropout usage: {self.use_dropout}" )
            print(f" Activate last layer: {self.activate_output}")
            print(f" Dropout structure: {self.dropout_structure}")

        # NOTE(review): same unpickling caveat as above for the optimizer file.
        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}

        if self.shorthash() == optimizer_saved_dict.get('hash', None):
            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
        else:
            self.optimizer_state_dict = None
        if self.optimizer_state_dict:
            self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
            if shared.opts.print_hypernet_extra:
                print("Loaded existing optimizer from checkpoint")
                print(f"Optimizer name is {self.optimizer_name}")
        else:
            self.optimizer_name = "AdamW"
            if shared.opts.print_hypernet_extra:
                print("No saved optimizer exists in checkpoint")

        # Integer keys hold the per-size weight pairs; string keys are metadata.
        for size, sd in state_dict.items():
            if isinstance(size, int):
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                )

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
        self.eval()

    def shorthash(self):
        """Return the first 10 characters of ``hashes.sha256`` for ``self.filename``, or ``None``."""
        sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')

        return sha256[0:10] if sha256 else None
314
+
315
+
316
def list_hypernetworks(path):
    """Map hypernetwork name (file stem) to file path for every ``*.pt`` file under ``path``.

    The search is recursive and the files are visited in case-insensitive path
    order, so when two files share a stem the later one wins.
    """
    pattern = os.path.join(path, '**/*.pt')
    found = sorted(glob.iglob(pattern, recursive=True), key=str.lower)

    by_name = {}
    for full_path in found:
        stem, _extension = os.path.splitext(os.path.basename(full_path))
        # Prevent a hypothetical "None.pt" from being listed.
        if stem == "None":
            continue
        by_name[stem] = full_path
    return by_name
324
+
325
+
326
def load_hypernetwork(name):
    """Load the hypernetwork registered as ``name`` in ``shared.hypernetworks``.

    Returns the loaded Hypernetwork, or ``None`` when the name is unknown or
    loading raised (the error is reported through ``errors.report``).
    """
    path = shared.hypernetworks.get(name, None)

    if path is not None:
        try:
            loaded = Hypernetwork()
            loaded.load(path)
        except Exception:
            errors.report(f"Error loading hypernetwork {path}", exc_info=True)
        else:
            return loaded

    return None
339
+
340
+
341
def load_hypernetworks(names, multipliers=None):
    """Replace the contents of ``shared.loaded_hypernetworks`` with the networks in ``names``.

    Networks already loaded under a requested name are reused; the others are
    loaded with load_hypernetwork and skipped when that returns ``None``.
    Entry ``i`` gets ``multipliers[i]``, or 1.0 when ``multipliers`` is empty.
    """
    reusable = {
        loaded.name: loaded
        for loaded in shared.loaded_hypernetworks
        if loaded.name in names
    }

    shared.loaded_hypernetworks.clear()

    for index, name in enumerate(names):
        hypernetwork = reusable.get(name, None)
        if hypernetwork is None:
            hypernetwork = load_hypernetwork(name)

        if hypernetwork is not None:
            hypernetwork.set_multiplier(multipliers[index] if multipliers else 1.0)
            shared.loaded_hypernetworks.append(hypernetwork)
360
+
361
+
362
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
    """Run the key/value contexts through the module pair matching ``context_k.shape[2]``.

    Both contexts are returned unchanged when ``hypernetwork`` is ``None`` or
    has no modules for that size. When ``layer`` is given, the two modules are
    also stored on it as ``hyper_k`` and ``hyper_v``.
    """
    available = {} if hypernetwork is None else hypernetwork.layers
    module_pair = available.get(context_k.shape[2], None)

    if module_pair is None:
        return context_k, context_v

    module_k = module_pair[0]
    module_v = module_pair[1]

    if layer is not None:
        layer.hyper_k = module_k
        layer.hyper_v = module_v

    # Cast to float for the hypernetwork module, then back for the UNet.
    context_k = devices.cond_cast_unet(module_k(devices.cond_cast_float(context_k)))
    context_v = devices.cond_cast_unet(module_v(devices.cond_cast_float(context_v)))
    return context_k, context_v
375
+
376
+
377
def apply_hypernetworks(hypernetworks, context, layer=None):
    """Chain every hypernetwork over ``context`` and return the ``(key, value)`` contexts.

    Both contexts start as ``context``; each hypernetwork receives the output
    of the previous one.
    """
    key_context, value_context = context, context
    for hypernetwork in hypernetworks:
        key_context, value_context = apply_single_hypernetwork(hypernetwork, key_context, value_context, layer)

    return key_context, value_context
384
+
385
+
386
def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
    """Attention forward pass that feeds the context through the loaded hypernetworks.

    The key and value projections are computed from the contexts returned by
    apply_hypernetworks; ``context`` falls back to ``x`` when not given.
    Extra keyword arguments are accepted and ignored.
    """
    heads = self.heads

    query = self.to_q(x)
    context = default(context, x)

    hyper_k, hyper_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
    key = self.to_k(hyper_k)
    value = self.to_v(hyper_v)

    # Fold the head dimension into the batch dimension.
    query, key, value = [rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (query, key, value)]

    scores = einsum('b i d, b j d -> b i j', query, key) * self.scale

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        fill_value = -torch.finfo(scores.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=heads)
        # Positions where the mask is False get the most negative finite value.
        scores.masked_fill_(~mask, fill_value)

    # attention, what we cannot get enough of
    probabilities = scores.softmax(dim=-1)

    attended = einsum('b i j, b j d -> b i d', probabilities, value)
    attended = rearrange(attended, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(attended)
412
+
413
+
414
def stack_conds(conds):
    """Stack conditioning tensors into one batch, padding shorter ones.

    Tensors with fewer rows than the longest one are extended by repeating
    their last row (same as in reconstruct_multicond_batch). The input
    sequence is left untouched; padding happens on a new list.

    Args:
        conds: non-empty sequence of 2-D tensors with the same second dimension.

    Returns:
        A tensor with ``len(conds)`` as its first dimension.
    """
    if len(conds) == 1:
        return torch.stack(conds)

    token_count = max(x.shape[0] for x in conds)
    padded = []
    for cond in conds:
        missing = token_count - cond.shape[0]
        if missing:
            last_vector_repeated = cond[-1:].repeat([missing, 1])
            cond = torch.vstack([cond, last_vector_repeated])
        padded.append(cond)

    return torch.stack(padded)
427
+
428
+
429
def statistics(data):
    """Format mean and standard error of all losses and of the last 32 losses.

    Args:
        data: non-empty collection of numbers; any iterable works, including
            ``collections.deque``, which does not support slicing.

    Returns:
        Tuple ``(total_information, recent_information)`` of strings shaped
        like ``"loss:<mean>±(<standard error>)"``.

    Raises:
        statistics.StatisticsError: if ``data`` is empty.
    """
    # Copy into a list once so that slicing works for every input type.
    samples = list(data)

    def _summary(label, values):
        # With fewer than two values the standard deviation is reported as 0.
        std = stdev(values) if len(values) >= 2 else 0
        return f"{label}:{mean(values):.3f}" + u"\u00B1" + f"({std / (len(values) ** 0.5):.3f})"

    total_information = _summary("loss", samples)
    recent_information = _summary("recent 32 loss", samples[-32:])
    return total_information, recent_information
442
+
443
+
444
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
    """Create a new hypernetwork, save it to the hypernetwork directory and refresh the list.

    Args:
        name: display name; characters other than alphanumerics and ``._- `` are dropped.
        enable_sizes: context widths to create module pairs for (converted with ``int``).
        overwrite_old: when false, an existing file with the same name is an error.
        layer_structure: list of width multipliers or a comma-separated string of them.
        activation_func, weight_init, add_layer_norm: passed through to Hypernetwork.
        use_dropout: when false, the dropout structure is all zeros.
        dropout_structure: comma-separated string or an already parsed
            list/tuple of probabilities; only used when ``use_dropout`` is true.

    Raises:
        AssertionError: if the cleaned name is empty or the file already exists.
    """
    # Remove illegal characters from name.
    name = "".join(x for x in name if (x.isalnum() or x in "._- "))
    assert name, "Name cannot be empty!"

    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"

    if isinstance(layer_structure, str):
        layer_structure = [float(x.strip()) for x in layer_structure.split(",")]

    if use_dropout and dropout_structure and isinstance(dropout_structure, str):
        dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
    elif use_dropout and dropout_structure and isinstance(dropout_structure, (list, tuple)):
        # An already parsed structure is kept instead of being replaced by zeros.
        dropout_structure = [float(x) for x in dropout_structure]
    else:
        dropout_structure = [0] * len(layer_structure)

    # The class defined in this module is used directly, so this does not
    # depend on the package path having been imported elsewhere.
    hypernet = Hypernetwork(
        name=name,
        enable_sizes=[int(x) for x in enable_sizes],
        layer_structure=layer_structure,
        activation_func=activation_func,
        weight_init=weight_init,
        add_layer_norm=add_layer_norm,
        use_dropout=use_dropout,
        dropout_structure=dropout_structure
    )
    hypernet.save(fn)

    shared.reload_hypernetworks()
474
+
475
+
476
def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):
    """Train the named hypernetwork up to ``steps`` optimizer steps and save it.

    The hypernetwork is loaded from ``shared.hypernetworks``, made the only
    loaded hypernetwork, and trained with gradient accumulation over
    ``gradient_step`` batches per optimizer step. Intermediate checkpoints are
    written every ``save_hypernetwork_every`` steps and preview images every
    ``create_image_every`` steps (0 or ``None`` disables either). Exceptions
    raised inside the training loop are reported through ``errors.report`` and
    the hypernetwork is still saved afterwards.

    Returns:
        Tuple ``(hypernetwork, filename)`` where ``filename`` is the final
        save path inside ``shared.cmd_opts.hypernetwork_dir``.
    """
    from modules import images, processing

    save_hypernetwork_every = save_hypernetwork_every or 0
    create_image_every = create_image_every or 0
    template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
    template_file = template_file.path

    path = shared.hypernetworks.get(hypernetwork_name, None)
    hypernetwork = Hypernetwork()
    hypernetwork.load(path)
    shared.loaded_hypernetworks = [hypernetwork]

    shared.state.job = "train-hypernetwork"
    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

    # Everything from the last "(" onwards is dropped from the name.
    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
    unload = shared.opts.unload_models_when_training

    if save_hypernetwork_every > 0:
        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
        os.makedirs(hypernetwork_dir, exist_ok=True)
    else:
        hypernetwork_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    checkpoint = sd_models.select_checkpoint()

    initial_step = hypernetwork.step or 0
    if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
    if clip_grad:
        # The clipping threshold is scheduled with the same scheduler class as the learning rate.
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)

    if shared.opts.training_enable_tensorboard:
        tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)

    if shared.opts.save_training_settings_to_txt:
        saved_params = dict(
            model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
            **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
        )
        logging.save_settings_to_file(log_directory, {**saved_params, **locals()})

    latent_sampling_method = ds.latent_sampling_method

    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

    old_parallel_processing_allowed = shared.parallel_processing_allowed

    if unload:
        shared.parallel_processing_allowed = False
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)

    weights = hypernetwork.weights()
    hypernetwork.train()

    # Here we use optimizer from saved HN, or we can specify as UI option.
    if hypernetwork.optimizer_name in optimizer_dict:
        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
        optimizer_name = hypernetwork.optimizer_name
    else:
        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
        optimizer_name = 'AdamW'

    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
        try:
            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
        except RuntimeError as e:
            print("Cannot resume from saved optimizer!")
            print(e)

    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
    loss_step = 0
    _loss_step = 0  # internal: loss accumulated over the current accumulation window
    loss_logging = deque(maxlen=len(ds) * 3)  # this should be configurable parameter, this is 3 * epoch(dataset size)

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        sd_hijack_checkpoint.add()

        for _ in range((steps-initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                scheduler.apply(optimizer, hypernetwork.step)
                if scheduler.finished:
                    break
                if shared.state.interrupted:
                    break

                if clip_grad:
                    clip_grad_sched.step(hypernetwork.step)

                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if use_weight:
                        w = batch.weight.to(devices.device, non_blocking=pin_memory)
                    if tag_drop_out != 0 or shuffle_tags:
                        # The prompt text changes per batch here, so it is re-encoded
                        # instead of using the conds stored in the batch.
                        shared.sd_model.cond_stage_model.to(devices.device)
                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                    else:
                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
                    if use_weight:
                        loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
                        del w
                    else:
                        loss = shared.sd_model.forward(x, c)[0] / gradient_step
                    del x
                    del c

                    _loss_step += loss.item()
                scaler.scale(loss).backward()

                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue
                loss_logging.append(_loss_step)
                if clip_grad:
                    # Gradients still carry the GradScaler's loss scale at this
                    # point; unscale them so the threshold applies to the real
                    # gradient values. scaler.step() will not unscale again.
                    scaler.unscale_(optimizer)
                    clip_grad(weights, clip_grad_sched.learn_rate)

                scaler.step(optimizer)
                scaler.update()
                hypernetwork.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = hypernetwork.step + 1

                epoch_num = hypernetwork.step // steps_per_epoch
                epoch_step = hypernetwork.step % steps_per_epoch

                description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
                pbar.set_description(description)
                if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
                    # Before saving, change name to match current checkpoint.
                    hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
                    last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
                    hypernetwork.optimizer_name = optimizer_name
                    if shared.opts.save_optimizer_state:
                        hypernetwork.optimizer_state_dict = optimizer.state_dict()
                    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
                    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.

                if shared.opts.training_enable_tensorboard:
                    epoch_num = hypernetwork.step // len(ds)
                    epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
                    mean_loss = sum(loss_logging) / len(loss_logging)
                    textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)

                textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
                })

                if images_dir is not None and steps_done % create_image_every == 0:
                    forced_filename = f'{hypernetwork_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)
                    hypernetwork.eval()
                    # Save the RNG state so generating the preview does not
                    # change the random sequence used by training.
                    rng_state = torch.get_rng_state()
                    cuda_rng_state = None
                    if torch.cuda.is_available():
                        cuda_rng_state = torch.cuda.get_rng_state_all()
                    shared.sd_model.cond_stage_model.to(devices.device)
                    shared.sd_model.first_stage_model.to(devices.device)

                    p = processing.StableDiffusionProcessingTxt2Img(
                        sd_model=shared.sd_model,
                        do_not_save_grid=True,
                        do_not_save_samples=True,
                    )

                    p.disable_extra_networks = True

                    if preview_from_txt2img:
                        p.prompt = preview_prompt
                        p.negative_prompt = preview_negative_prompt
                        p.steps = preview_steps
                        p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
                        p.cfg_scale = preview_cfg_scale
                        p.seed = preview_seed
                        p.width = preview_width
                        p.height = preview_height
                    else:
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height

                    preview_text = p.prompt

                    with closing(p):
                        processed = processing.process_images(p)
                        image = processed.images[0] if len(processed.images) > 0 else None

                    if unload:
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                        shared.sd_model.first_stage_model.to(devices.cpu)
                    torch.set_rng_state(rng_state)
                    if torch.cuda.is_available():
                        torch.cuda.set_rng_state_all(cuda_rng_state)
                    hypernetwork.train()
                    if image is not None:
                        shared.state.assign_current_image(image)
                        if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
                            textual_inversion.tensorboard_add_image(tensorboard_writer,
                                                                    f"Validation at epoch {epoch_num}", image,
                                                                    hypernetwork.step)
                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

                shared.state.job_no = hypernetwork.step

                shared.state.textinfo = f"""
<p>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
    except Exception:
        errors.report("Exception in training hypernetwork", exc_info=True)
    finally:
        pbar.leave = False
        pbar.close()
        hypernetwork.eval()
        sd_hijack_checkpoint.remove()

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
    hypernetwork.optimizer_name = optimizer_name
    if shared.opts.save_optimizer_state:
        hypernetwork.optimizer_state_dict = optimizer.state_dict()
    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)

    del optimizer
    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
    shared.sd_model.cond_stage_model.to(devices.device)
    shared.sd_model.first_stage_model.to(devices.device)
    shared.parallel_processing_allowed = old_parallel_processing_allowed

    return hypernetwork, filename
773
+
774
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
    """Save ``hypernetwork`` to ``filename`` under ``hypernetwork_name``, tagged with ``checkpoint``.

    The name and the checkpoint hash/model name are set on the hypernetwork
    before saving. If anything raises, the three attributes are restored to
    their previous values and the exception propagates.
    """
    previous_name = hypernetwork.name
    previous_checkpoint = getattr(hypernetwork, "sd_checkpoint", None)
    previous_checkpoint_name = getattr(hypernetwork, "sd_checkpoint_name", None)
    try:
        hypernetwork.sd_checkpoint = checkpoint.shorthash
        hypernetwork.sd_checkpoint_name = checkpoint.model_name
        hypernetwork.name = hypernetwork_name
        hypernetwork.save(filename)
    except BaseException:
        # Roll back on any failure, including interrupts, then re-raise.
        hypernetwork.sd_checkpoint = previous_checkpoint
        hypernetwork.sd_checkpoint_name = previous_checkpoint_name
        hypernetwork.name = previous_name
        raise
pochi_4-l-500721.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:907acee5500505c3944f6082ced0e7de2cea41266c632e59db88fdaa3a1fc9e1
3
+ size 1754881537
pochi_pt3.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f3d79ddd63f0c3f7a0abb5f00677b7619e77022bc65c21db7a4c35c7fcdee94
3
+ size 1930261473