speeeeed
Browse files
- samples/2b_192x384_0.jpg +3 -0
- samples/2b_256x384_0.jpg +3 -0
- samples/2b_320x384_0.jpg +3 -0
- samples/2b_384x192_0.jpg +3 -0
- samples/2b_384x256_0.jpg +3 -0
- samples/2b_384x320_0.jpg +3 -0
- samples/2b_384x384_0.jpg +3 -0
- train.py +5 -5
samples/2b_192x384_0.jpg
ADDED
|
Git LFS Details
|
samples/2b_256x384_0.jpg
ADDED
|
Git LFS Details
|
samples/2b_320x384_0.jpg
ADDED
|
Git LFS Details
|
samples/2b_384x192_0.jpg
ADDED
|
Git LFS Details
|
samples/2b_384x256_0.jpg
ADDED
|
Git LFS Details
|
samples/2b_384x320_0.jpg
ADDED
|
Git LFS Details
|
samples/2b_384x384_0.jpg
ADDED
|
Git LFS Details
|
train.py
CHANGED
|
@@ -27,19 +27,19 @@ import torch.nn.functional as F
|
|
| 27 |
# --------------------------- Параметры ---------------------------
|
| 28 |
ds_path = "datasets/384"
|
| 29 |
project = "2b"
|
| 30 |
- batch_size = 50
|
| 31 |
base_learning_rate = 8e-5
|
| 32 |
min_learning_rate = 4e-5
|
| 33 |
- num_epochs =
|
| 34 |
# samples/save per epoch
|
| 35 |
- sample_interval_share =
|
| 36 |
use_wandb = True
|
| 37 |
save_model = True
|
| 38 |
use_decay = True
|
| 39 |
fbp = False # fused backward pass
|
| 40 |
optimizer_type = "adam8bit"
|
| 41 |
torch_compile = False
|
| 42 |
- unet_gradient =
|
| 43 |
clip_sample = False #Scheduler
|
| 44 |
fixed_seed = False
|
| 45 |
shuffle = True
|
|
@@ -400,7 +400,7 @@ if os.path.isdir(latest_checkpoint):
|
|
| 400 |
if torch_compile:
|
| 401 |
print("compiling")
|
| 402 |
torch.set_float32_matmul_precision('high')
|
| 403 |
- unet = torch.compile(unet)
|
| 404 |
print("compiling - ok")
|
| 405 |
|
| 406 |
if lora_name:
|
|
|
|
| 27 |
# --------------------------- Параметры ---------------------------
|
| 28 |
ds_path = "datasets/384"
|
| 29 |
project = "2b"
|
| 30 |
+ batch_size = 16 #50
|
| 31 |
base_learning_rate = 8e-5
|
| 32 |
min_learning_rate = 4e-5
|
| 33 |
+ num_epochs = 10
|
| 34 |
# samples/save per epoch
|
| 35 |
+ sample_interval_share = 10
|
| 36 |
use_wandb = True
|
| 37 |
save_model = True
|
| 38 |
use_decay = True
|
| 39 |
fbp = False # fused backward pass
|
| 40 |
optimizer_type = "adam8bit"
|
| 41 |
torch_compile = False
|
| 42 |
+ unet_gradient = False
|
| 43 |
clip_sample = False #Scheduler
|
| 44 |
fixed_seed = False
|
| 45 |
shuffle = True
|
|
|
|
| 400 |
if torch_compile:
|
| 401 |
print("compiling")
|
| 402 |
torch.set_float32_matmul_precision('high')
|
| 403 |
+ unet = torch.compile(unet, mode="reduce-overhead", fullgraph=False)
|
| 404 |
print("compiling - ok")
|
| 405 |
|
| 406 |
if lora_name:
|