Spaces:
Build error
Build error
Update trainer.py
Browse files- trainer.py +17 -9
trainer.py
CHANGED
|
@@ -60,11 +60,13 @@ class Trainer:
|
|
| 60 |
resolution_s: str,
|
| 61 |
concept_images: list | None,
|
| 62 |
concept_prompt: str,
|
|
|
|
| 63 |
n_steps: int,
|
| 64 |
learning_rate: float,
|
| 65 |
train_text_encoder: bool,
|
| 66 |
learning_rate_text: float,
|
| 67 |
gradient_accumulation: int,
|
|
|
|
| 68 |
fp16: bool,
|
| 69 |
use_8bit_adam: bool,
|
| 70 |
) -> tuple[dict, list[pathlib.Path]]:
|
|
@@ -85,18 +87,24 @@ class Trainer:
|
|
| 85 |
self.prepare_dataset(concept_images, resolution)
|
| 86 |
|
| 87 |
command = f'''
|
| 88 |
-
accelerate launch
|
| 89 |
-
--pretrained_model_name_or_path={base_model}
|
| 90 |
-
--instance_data_dir={self.instance_data_dir}
|
|
|
|
| 91 |
--output_dir={self.output_dir} \
|
|
|
|
| 92 |
--instance_prompt="{concept_prompt}" \
|
| 93 |
-
--
|
| 94 |
-
--
|
| 95 |
-
--
|
| 96 |
-
--
|
| 97 |
-
--
|
|
|
|
| 98 |
--lr_warmup_steps=0 \
|
| 99 |
-
--max_train_steps={n_steps}
|
|
|
|
|
|
|
|
|
|
| 100 |
'''
|
| 101 |
if fp16:
|
| 102 |
command += ' --mixed_precision fp16'
|
|
|
|
| 60 |
resolution_s: str,
|
| 61 |
concept_images: list | None,
|
| 62 |
concept_prompt: str,
|
| 63 |
+
class_prompt: str,
|
| 64 |
n_steps: int,
|
| 65 |
learning_rate: float,
|
| 66 |
train_text_encoder: bool,
|
| 67 |
learning_rate_text: float,
|
| 68 |
gradient_accumulation: int,
|
| 69 |
+
batch-size: int,
|
| 70 |
fp16: bool,
|
| 71 |
use_8bit_adam: bool,
|
| 72 |
) -> tuple[dict, list[pathlib.Path]]:
|
|
|
|
| 87 |
self.prepare_dataset(concept_images, resolution)
|
| 88 |
|
| 89 |
command = f'''
|
| 90 |
+
accelerate launch custom-diffusion/src/diffuser_training.py \
|
| 91 |
+
--pretrained_model_name_or_path={base_model} \
|
| 92 |
+
--instance_data_dir={self.instance_data_dir} \
|
| 93 |
+
--class_data_dir={self.class_data_dir} \
|
| 94 |
--output_dir={self.output_dir} \
|
| 95 |
+
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
| 96 |
--instance_prompt="{concept_prompt}" \
|
| 97 |
+
--class_prompt="{class_prompt}" \
|
| 98 |
+
--resolution={resolution} \
|
| 99 |
+
--train_batch_size={batch-size} \
|
| 100 |
+
--gradient_accumulation_steps={gradient_accumulation} \
|
| 101 |
+
--learning_rate={learning_rate} \
|
| 102 |
+
--lr_scheduler="constant" \
|
| 103 |
--lr_warmup_steps=0 \
|
| 104 |
+
--max_train_steps={n_steps} \
|
| 105 |
+
--num_class_images=200 \
|
| 106 |
+
--scale_lr \
|
| 107 |
+
--modifier_token "<new1>"
|
| 108 |
'''
|
| 109 |
if fp16:
|
| 110 |
command += ' --mixed_precision fp16'
|