bomdey committed on
Commit
3d76fe4
·
1 Parent(s): 0921fb4

Add fine-tuning script

Browse files
Files changed (1) hide show
  1. FineTuningLora.ipynb +550 -0
FineTuningLora.ipynb ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "markdown",
21
+ "source": [
22
+ "## Fine-tuning script for the master's thesis \"*AI-based Image Generation to Support Easy Language*\"\n",
23
+ "\n",
24
+ "This is an adapted version of this [colab](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-LoRA-dreambooth.ipynb) and makes use of the fine-tuning script from this [repository](https://github.com/Linaqruf/kohya-trainer) (commit: `3d494d8`).\n",
25
+ "\n",
26
+ "Execute all cells to reproduce the weights used in the thesis. The final training run of the thesis used a T4 GPU with the \"extended ram\" option disabled."
27
+ ],
28
+ "metadata": {
29
+ "id": "HWkM_jf5v42U"
30
+ }
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {
36
+ "id": "nb06s6qR0FFP"
37
+ },
38
+ "outputs": [],
39
+ "source": [
40
+ "# @title ## Install Dependencies\n",
41
+ "import os\n",
42
+ "import zipfile\n",
43
+ "import shutil\n",
44
+ "import time\n",
45
+ "from subprocess import getoutput\n",
46
+ "from IPython.utils import capture\n",
47
+ "from google.colab import drive\n",
48
+ "\n",
49
+ "\n",
50
+ "%store -r\n",
51
+ "\n",
52
+ "# root_dir\n",
53
+ "root_dir = \"/content\"\n",
54
+ "repo_dir = os.path.join(root_dir, \"kohya-trainer\")\n",
55
+ "training_dir = os.path.join(root_dir, \"LoRA\")\n",
56
+ "pretrained_model = os.path.join(root_dir, \"pretrained_model\")\n",
57
+ "vae_dir = os.path.join(root_dir, \"vae\")\n",
58
+ "config_dir = os.path.join(training_dir, \"config\")\n",
59
+ "\n",
60
+ "# repo_dir\n",
61
+ "accelerate_config = os.path.join(repo_dir, \"accelerate_config/config.yaml\")\n",
62
+ "tools_dir = os.path.join(repo_dir, \"tools\")\n",
63
+ "finetune_dir = os.path.join(repo_dir, \"finetune\")\n",
64
+ "\n",
65
+ "# output_dir\n",
66
+ "output_to_drive = False\n",
67
+ "output_dir = \"/content/LoRA/output\" if not output_to_drive else \"/content/drive/MyDrive/LoRA/output\"\n",
68
+ "sample_dir = os.path.join(output_dir, \"sample\")\n",
69
+ "\n",
70
+ "for store in [\n",
71
+ " \"root_dir\",\n",
72
+ " \"repo_dir\",\n",
73
+ " \"training_dir\",\n",
74
+ " \"pretrained_model\",\n",
75
+ " \"vae_dir\",\n",
76
+ " \"accelerate_config\",\n",
77
+ " \"tools_dir\",\n",
78
+ " \"finetune_dir\",\n",
79
+ " \"config_dir\",\n",
80
+ " \"output_dir\",\n",
81
+ " \"sample_dir\"\n",
82
+ "]:\n",
83
+ " with capture.capture_output() as cap:\n",
84
+ " %store {store}\n",
85
+ " del cap\n",
86
+ "\n",
87
+ "repo_url = \"https://github.com/Linaqruf/kohya-trainer\"\n",
88
+ "submission_hash = \"3d494d83e4aea273f64716286a26d162a8df3317\"\n",
89
+ "branch = \"\"\n",
90
+ "mount_drive = True\n",
91
+ "verbose = False\n",
92
+ "\n",
93
+ "def read_file(filename):\n",
94
+ " with open(filename, \"r\") as f:\n",
95
+ " contents = f.read()\n",
96
+ " return contents\n",
97
+ "\n",
98
+ "\n",
99
+ "def write_file(filename, contents):\n",
100
+ " with open(filename, \"w\") as f:\n",
101
+ " f.write(contents)\n",
102
+ "\n",
103
+ "\n",
104
+ "def clone_repo(url):\n",
105
+ " if not os.path.exists(repo_dir):\n",
106
+ " os.chdir(root_dir)\n",
107
+ " !git clone {url} {repo_dir}\n",
108
+ " !git checkout {submission_hash}\n",
109
+ " else:\n",
110
+ " os.chdir(repo_dir)\n",
111
+ " !git checkout {submission_hash}\n",
112
+ "\n",
113
+ "def mount_drive():\n",
114
+ " if not os.path.exists(\"/content/drive\"):\n",
115
+ " drive.mount(\"/content/drive\")\n",
116
+ "\n",
117
+ "def set_environment_variables():\n",
118
+ " os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n",
119
+ " os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\"\n",
120
+ " os.environ[\"SAFETENSORS_FAST_GPU\"] = \"1\"\n",
121
+ "\n",
122
+ "def adjust_ld_library_path(cuda_path):\n",
123
+ " ld_library_path = os.environ.get(\"LD_LIBRARY_PATH\", \"\")\n",
124
+ " os.environ[\"LD_LIBRARY_PATH\"] = f\"{ld_library_path}:{cuda_path}\"\n",
125
+ "\n",
126
+ "def make_dirs():\n",
127
+ " for dir in [\n",
128
+ " training_dir,\n",
129
+ " config_dir,\n",
130
+ " pretrained_model,\n",
131
+ " vae_dir,\n",
132
+ " output_dir,\n",
133
+ " sample_dir\n",
134
+ " ]:\n",
135
+ " os.makedirs(dir, exist_ok=True)\n",
136
+ "\n",
137
+ "def install_dependencies(verbose=True, accelerate_config=\"accelerate_config.yaml\"):\n",
138
+ " \"\"\"Install all requirements and dependencies\"\"\"\n",
139
+ " gpu_info = getoutput(\"nvidia-smi\")\n",
140
+ " if \"T4\" in gpu_info:\n",
141
+ " update_gpu_configuration()\n",
142
+ "\n",
143
+ " install_requirements(verbose)\n",
144
+ " install_pytorch_libraries(verbose)\n",
145
+ "\n",
146
+ " configure_accelerate(accelerate_config)\n",
147
+ "\n",
148
+ "def update_gpu_configuration():\n",
149
+ " \"\"\"Modify the utility file to use GPU (replace 'cpu' with 'cuda')\"\"\"\n",
150
+ " !sed -i \"s@cpu@cuda@\" library/model_util.py\n",
151
+ "\n",
152
+ "def install_requirements(verbose):\n",
153
+ " \"\"\"Install Python packages from requirements.txt\"\"\"\n",
154
+ " !pip install {\"-q\" if not verbose else \"\"} --upgrade -r requirements.txt\n",
155
+ "\n",
156
+ "def install_pytorch_libraries(verbose):\n",
157
+ " \"\"\"Install specific versions of PyTorch and related libraries\"\"\"\n",
158
+ " !pip install {\"-q\" if not verbose else \"\"} torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1+cu118 torchtext==0.15.1 torchdata==0.6.0 xformers==0.0.19 triton==2.0.0 --extra-index-url https://download.pytorch.org/whl/cu118 -U\n",
159
+ "\n",
160
+ "def configure_accelerate(accelerate_config):\n",
161
+ " \"\"\"Configure Accelerate if the specified config file does not exist\"\"\"\n",
162
+ " from accelerate.utils import write_basic_config\n",
163
+ "\n",
164
+ " if not os.path.exists(accelerate_config):\n",
165
+ " write_basic_config(save_location=accelerate_config)\n",
166
+ "\n",
167
+ "\n",
168
+ "def main():\n",
169
+ " \"\"\"Setup directories and environment specific variables\"\"\"\n",
170
+ " os.chdir(root_dir)\n",
171
+ "\n",
172
+ " if mount_drive:\n",
173
+ " mount_drive()\n",
174
+ "\n",
175
+ " make_dirs()\n",
176
+ "\n",
177
+ " clone_repo(repo_url)\n",
178
+ "\n",
179
+ " os.chdir(repo_dir)\n",
180
+ "\n",
181
+ " !apt install aria2 {\"-qq\" if not verbose else \"\"}\n",
182
+ "\n",
183
+ " install_dependencies(verbose=verbose, accelerate_config=accelerate_config)\n",
184
+ " time.sleep(3)\n",
185
+ "\n",
186
+ " set_environment_variables()\n",
187
+ "\n",
188
+ " cuda_path = \"/usr/local/cuda-11.8/targets/x86_64-linux/lib/\"\n",
189
+ " adjust_ld_library_path(cuda_path)\n",
190
+ "\n",
191
+ "main()\n"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "source": [
197
+ "# @title ## Download Model and VAE\n",
198
+ "\n",
199
+ "%store -r\n",
200
+ "\n",
201
+ "os.chdir(root_dir)\n",
202
+ "\n",
203
+ "hf_token = \"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\"\n",
204
+ "user_header = f'\"Authorization: Bearer {hf_token}\"'\n",
205
+ "\n",
206
+ "# model\n",
207
+ "model_name = \"Stable-Diffusion-v1-5.safetensors\"\n",
208
+ "model_url = \"https://huggingface.co/bomdey/plAInlang/resolve/main/stable_diffusion_1_5-pruned.safetensors\"\n",
209
+ "\n",
210
+ "# Download pretrained model from huggingface\n",
211
+ "pretrained_model_name_or_path = os.path.join(pretrained_model, model_name)\n",
212
+ "if not os.path.exists(pretrained_model_name_or_path):\n",
213
+ " !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {pretrained_model} -o {model_name} \"{model_url}\"\n",
214
+ "\n",
215
+ "# vae\n",
216
+ "vae_name = \"stablediffusion.vae.pt\"\n",
217
+ "vae_url = \"https://huggingface.co/bomdey/plAInlang/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt\"\n",
218
+ "\n",
219
+ "# Download vae from huggingface\n",
220
+ "vae = os.path.join(vae_dir, vae_name)\n",
221
+ "if not os.path.exists(vae):\n",
222
+ " !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {vae_dir} -o {vae_name} \"{vae_url}\""
223
+ ],
224
+ "metadata": {
225
+ "id": "cjt6t_ob01g7"
226
+ },
227
+ "execution_count": null,
228
+ "outputs": []
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "source": [
233
+ "# @title ## Load Dataset from Huggingface\n",
234
+ "\n",
235
+ "%store -r\n",
236
+ "\n",
237
+ "dataset_submission_hash = \"731fd74dbed6197f88d935828608fff4a3b3299d\"\n",
238
+ "hf_dataset_repo = \"https://huggingface.co/datasets/bomdey/plAInLang/\"\n",
239
+ "data_destination_dir = \"/content/dataset\"\n",
240
+ "\n",
241
+ "if not os.path.exists(data_destination_dir):\n",
242
+ " !git clone {hf_dataset_repo} {data_destination_dir}\n",
243
+ " time.sleep(3)\n",
244
+ "\n",
245
+ "os.chdir(data_destination_dir)\n",
246
+ "!git checkout {dataset_submission_hash}\n",
247
+ "\n",
248
+ "%store data_destination_dir\n",
249
+ "\n",
250
+ "os.chdir(root_dir)\n",
251
+ "\n",
252
+ "# Setup directory for training data\n",
253
+ "train_data_dir = data_destination_dir\n",
254
+ "\n",
255
+ "%store train_data_dir"
256
+ ],
257
+ "metadata": {
258
+ "id": "llTHbemwhzTv"
259
+ },
260
+ "execution_count": null,
261
+ "outputs": []
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "source": [
266
+ "# @title ## Dataset Config\n",
267
+ "import toml\n",
268
+ "import glob\n",
269
+ "\n",
270
+ "dataset_repeats = 10\n",
271
+ "activation_word = \"pl41nl4ng\"\n",
272
+ "caption_extension = \".txt\"\n",
273
+ "resolution = 512\n",
274
+ "flip_aug = False\n",
275
+ "keep_tokens = 0\n",
276
+ "\n",
277
+ "def find_image_files(path):\n",
278
+ " \"\"\"Get all images from a given path\"\"\"\n",
279
+ " supported_extensions = (\".png\", \".jpg\", \".jpeg\", \".webp\", \".bmp\")\n",
280
+ " return [file for file in glob.glob(path + '/**/*', recursive=True) if file.lower().endswith(supported_extensions)]\n",
281
+ "\n",
282
+ "def process_data_dir(data_dir, default_num_repeats, default_class_token):\n",
283
+ " \"\"\"Process a data directory and create subsets for image datasets\"\"\"\n",
284
+ " subsets = []\n",
285
+ " images = find_image_files(data_dir)\n",
286
+ " if images:\n",
287
+ " subsets.append({\n",
288
+ " \"image_dir\": data_dir,\n",
289
+ " \"class_tokens\": default_class_token,\n",
290
+ " \"num_repeats\": default_num_repeats,\n",
291
+ " **({}),\n",
292
+ " })\n",
293
+ "\n",
294
+ " return subsets\n",
295
+ "\n",
296
+ "\n",
297
+ "train_subsets = process_data_dir(train_data_dir, dataset_repeats, activation_word)\n",
298
+ "config = {\n",
299
+ " \"general\": {\n",
300
+ " \"enable_bucket\": True,\n",
301
+ " \"caption_extension\": caption_extension,\n",
302
+ " \"shuffle_caption\": True,\n",
303
+ " \"keep_tokens\": keep_tokens,\n",
304
+ " \"bucket_reso_steps\": 64,\n",
305
+ " \"bucket_no_upscale\": False,\n",
306
+ " },\n",
307
+ " \"datasets\": [\n",
308
+ " {\n",
309
+ " \"resolution\": resolution,\n",
310
+ " \"min_bucket_reso\": 256,\n",
311
+ " \"max_bucket_reso\": 1024,\n",
312
+ " \"caption_dropout_rate\": 0,\n",
313
+ " \"caption_tag_dropout_rate\": 0,\n",
314
+ " \"caption_dropout_every_n_epochs\": 0,\n",
315
+ " \"flip_aug\": flip_aug,\n",
316
+ " \"color_aug\": False,\n",
317
+ " \"face_crop_aug_range\": None,\n",
318
+ " \"subsets\": train_subsets,\n",
319
+ " }\n",
320
+ " ],\n",
321
+ "}\n",
322
+ "\n",
323
+ "\n",
324
+ "dataset_config = os.path.join(config_dir, \"dataset_config.toml\")\n",
325
+ "config_str = toml.dumps(config)\n",
326
+ "with open(dataset_config, \"w\") as f:\n",
327
+ " f.write(config_str)\n",
328
+ "\n",
329
+ "print(config_str)"
330
+ ],
331
+ "metadata": {
332
+ "id": "7lgAaowm3uMV"
333
+ },
334
+ "execution_count": null,
335
+ "outputs": []
336
+ },
337
+ {
338
+ "cell_type": "code",
339
+ "source": [
340
+ "# @title ## Training Config\n",
341
+ "\n",
342
+ "import toml\n",
343
+ "import os\n",
344
+ "\n",
345
+ "project_name = \"pl41n-l4ng_final\"\n",
346
+ "%store project_name\n",
347
+ "\n",
348
+ "%store -r\n",
349
+ "\n",
350
+ "# Lora and Optimizer\n",
351
+ "conv_dim = 8\n",
352
+ "conv_alpha = 8\n",
353
+ "network_dim = 256\n",
354
+ "network_alpha = 256\n",
355
+ "network_weight = \"\"\n",
356
+ "network_module = \"networks.lora\"\n",
357
+ "network_args = \"\"\n",
358
+ "min_snr_gamma = 5\n",
359
+ "optimizer_type = \"AdamW8bit\" #\n",
360
+ "optimizer_args = \"\"\n",
361
+ "unet_lr = 5e-6\n",
362
+ "text_encoder_lr = 25e-7\n",
363
+ "lr_scheduler = \"constant\"\n",
364
+ "lr_warmup_steps = 0\n",
365
+ "lr_scheduler_num_cycles = 0\n",
366
+ "lr_scheduler_power = 0\n",
367
+ "\n",
368
+ "# Training\n",
369
+ "lowram = True\n",
370
+ "enable_sample_prompt = True\n",
371
+ "sampler = \"euler_a\"\n",
372
+ "noise_offset = 0.0\n",
373
+ "num_epochs = 30\n",
374
+ "vae_batch_size = 4\n",
375
+ "train_batch_size = 2\n",
376
+ "mixed_precision = \"fp16\"\n",
377
+ "save_precision = \"fp16\"\n",
378
+ "save_n_epochs_type = \"save_every_n_epochs\"\n",
379
+ "save_n_epochs_type_value = 1\n",
380
+ "save_model_as = \"safetensors\"\n",
381
+ "max_token_length = 225\n",
382
+ "clip_skip = 1\n",
383
+ "gradient_checkpointing = False\n",
384
+ "gradient_accumulation_steps = 1\n",
385
+ "seed = 42\n",
386
+ "logging_dir = \"/content/LoRA/logs\"\n",
387
+ "prior_loss_weight = 1.0\n",
388
+ "\n",
389
+ "os.chdir(repo_dir)\n",
390
+ "\n",
391
+ "sample_str = f\"\"\"\n",
392
+ " illustration in the style of pl41nl4ng, a man with glasses and a tie, solo, looking at viewer, smile, closed mouth, short hair, simple background, black background, shirt, 1boy, portrait, male focus, glasses \\\n",
393
+ " --n lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry \\\n",
394
+ " --w 512 \\\n",
395
+ " --h 512 \\\n",
396
+ " --l 7 \\\n",
397
+ " --s 28\n",
398
+ "\"\"\"\n",
399
+ "\n",
400
+ "config = {\n",
401
+ " \"model_arguments\": {\n",
402
+ " \"v2\": False,\n",
403
+ " \"v_parameterization\": False,\n",
404
+ " \"pretrained_model_name_or_path\": pretrained_model_name_or_path,\n",
405
+ " \"vae\": vae,\n",
406
+ " },\n",
407
+ " \"additional_network_arguments\": {\n",
408
+ " \"no_metadata\": False,\n",
409
+ " \"unet_lr\": float(unet_lr),\n",
410
+ " \"text_encoder_lr\": float(text_encoder_lr),\n",
411
+ " \"network_weights\": network_weight,\n",
412
+ " \"network_module\": network_module,\n",
413
+ " \"network_dim\": network_dim,\n",
414
+ " \"network_alpha\": network_alpha,\n",
415
+ " \"network_args\": None,\n",
416
+ " \"network_train_unet_only\": False,\n",
417
+ " \"network_train_text_encoder_only\": False,\n",
418
+ " \"training_comment\": None,\n",
419
+ " },\n",
420
+ " \"optimizer_arguments\": {\n",
421
+ " \"min_snr_gamma\": min_snr_gamma,\n",
422
+ " \"optimizer_type\": optimizer_type,\n",
423
+ " \"learning_rate\": unet_lr,\n",
424
+ " \"max_grad_norm\": 1.0,\n",
425
+ " \"optimizer_args\": None,\n",
426
+ " \"lr_scheduler\": lr_scheduler,\n",
427
+ " \"lr_warmup_steps\": lr_warmup_steps,\n",
428
+ " \"lr_scheduler_num_cycles\": None,\n",
429
+ " \"lr_scheduler_power\": None,\n",
430
+ " },\n",
431
+ " \"dataset_arguments\": {\n",
432
+ " \"cache_latents\": True,\n",
433
+ " \"debug_dataset\": False,\n",
434
+ " \"vae_batch_size\": vae_batch_size,\n",
435
+ " },\n",
436
+ " \"training_arguments\": {\n",
437
+ " \"output_dir\": output_dir,\n",
438
+ " \"output_name\": project_name,\n",
439
+ " \"save_precision\": save_precision,\n",
440
+ " \"save_every_n_epochs\": save_n_epochs_type_value,\n",
441
+ " \"save_n_epoch_ratio\": None,\n",
442
+ " \"save_last_n_epochs\": None,\n",
443
+ " \"save_state\": None,\n",
444
+ " \"save_last_n_epochs_state\": None,\n",
445
+ " \"resume\": None,\n",
446
+ " \"train_batch_size\": train_batch_size,\n",
447
+ " \"max_token_length\": 225,\n",
448
+ " \"mem_eff_attn\": False,\n",
449
+ " \"xformers\": True,\n",
450
+ " \"max_train_epochs\": num_epochs,\n",
451
+ " \"max_data_loader_n_workers\": 8,\n",
452
+ " \"persistent_data_loader_workers\": True,\n",
453
+ " \"seed\": seed if seed > 0 else None,\n",
454
+ " \"gradient_checkpointing\": gradient_checkpointing,\n",
455
+ " \"gradient_accumulation_steps\": gradient_accumulation_steps,\n",
456
+ " \"mixed_precision\": mixed_precision,\n",
457
+ " \"clip_skip\": clip_skip,\n",
458
+ " \"logging_dir\": logging_dir,\n",
459
+ " \"log_prefix\": project_name,\n",
460
+ " \"noise_offset\": None,\n",
461
+ " \"lowram\": lowram,\n",
462
+ " },\n",
463
+ " \"sample_prompt_arguments\": {\n",
464
+ " \"sample_every_n_steps\": None,\n",
465
+ " \"sample_every_n_epochs\": 1,\n",
466
+ " \"sample_sampler\": sampler,\n",
467
+ " },\n",
468
+ " \"dreambooth_arguments\": {\n",
469
+ " \"prior_loss_weight\": 1.0,\n",
470
+ " },\n",
471
+ " \"saving_arguments\": {\n",
472
+ " \"save_model_as\": save_model_as\n",
473
+ " },\n",
474
+ "}\n",
475
+ "\n",
476
+ "config_path = os.path.join(config_dir, \"config_file.toml\")\n",
477
+ "prompt_path = os.path.join(config_dir, \"sample_prompt.txt\")\n",
478
+ "\n",
479
+ "for key in config:\n",
480
+ " if isinstance(config[key], dict):\n",
481
+ " for sub_key in config[key]:\n",
482
+ " if config[key][sub_key] == \"\":\n",
483
+ " config[key][sub_key] = None\n",
484
+ " elif config[key] == \"\":\n",
485
+ " config[key] = None\n",
486
+ "\n",
487
+ "config_str = toml.dumps(config)\n",
488
+ "\n",
489
+ "def write_file(filename, contents):\n",
490
+ " with open(filename, \"w\") as f:\n",
491
+ " f.write(contents)\n",
492
+ "\n",
493
+ "write_file(config_path, config_str)\n",
494
+ "write_file(prompt_path, sample_str)\n",
495
+ "\n",
496
+ "print(config_str)"
497
+ ],
498
+ "metadata": {
499
+ "id": "igwhMSLQ5Dz_"
500
+ },
501
+ "execution_count": null,
502
+ "outputs": []
503
+ },
504
+ {
505
+ "cell_type": "code",
506
+ "source": [
507
+ "#@title ## Start Training\n",
508
+ "\n",
509
+ "sample_prompt = \"/content/LoRA/config/sample_prompt.txt\"\n",
510
+ "config_file = \"/content/LoRA/config/config_file.toml\"\n",
511
+ "dataset_config = \"/content/LoRA/config/dataset_config.toml\"\n",
512
+ "\n",
513
+ "accelerate_conf = {\n",
514
+ " \"config_file\" : accelerate_config,\n",
515
+ " \"num_cpu_threads_per_process\" : 1,\n",
516
+ "}\n",
517
+ "\n",
518
+ "train_conf = {\n",
519
+ " \"sample_prompts\" : sample_prompt,\n",
520
+ " \"dataset_config\" : dataset_config,\n",
521
+ " \"config_file\" : config_file\n",
522
+ "}\n",
523
+ "\n",
524
+ "def train(config):\n",
525
+ " \"\"\"Create training arguments\"\"\"\n",
526
+ " args = \"\"\n",
527
+ " for k, v in config.items():\n",
528
+ " if isinstance(v, str):\n",
529
+ " args += f'--{k}=\"{v}\" '\n",
530
+ " elif isinstance(v, int) and not isinstance(v, bool):\n",
531
+ " args += f\"--{k}={v} \"\n",
532
+ "\n",
533
+ " return args\n",
534
+ "\n",
535
+ "\n",
536
+ "accelerate_args = train(accelerate_conf)\n",
537
+ "train_args = train(train_conf)\n",
538
+ "final_args = f\"accelerate launch {accelerate_args} train_network.py {train_args}\"\n",
539
+ "\n",
540
+ "os.chdir(repo_dir)\n",
541
+ "!{final_args}"
542
+ ],
543
+ "metadata": {
544
+ "id": "0hyHFH845al3"
545
+ },
546
+ "execution_count": null,
547
+ "outputs": []
548
+ }
549
+ ]
550
+ }