Upload train_dreambooth_lora_sd3.py
Browse files
train_dreambooth_lora_sd3.py
CHANGED
|
@@ -791,7 +791,11 @@ class DreamBoothDataset(Dataset):
|
|
| 791 |
if class_data_root is not None:
|
| 792 |
self.class_data_root = Path(class_data_root)
|
| 793 |
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
| 794 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 795 |
if class_num is not None:
|
| 796 |
self.num_class_images = min(len(self.class_images_path), class_num)
|
| 797 |
else:
|
|
|
|
| 791 |
if class_data_root is not None:
|
| 792 |
self.class_data_root = Path(class_data_root)
|
| 793 |
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
| 794 |
+
|
| 795 |
+
#self.class_images_path = list(self.class_data_root.iterdir())
|
| 796 |
+
|
| 797 |
+
self.class_images_path = [p for p in self.class_data_root.iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]
|
| 798 |
+
|
| 799 |
if class_num is not None:
|
| 800 |
self.num_class_images = min(len(self.class_images_path), class_num)
|
| 801 |
else:
|