recoilme committed on
Commit
bccc5a1
·
1 Parent(s): d4bdd44
.gitattributes CHANGED
@@ -33,10 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
- media/refined.jpg filter=lfs diff=lfs merge=lfs -text
38
- test.ipynb filter=lfs diff=lfs merge=lfs -text
39
- wandb/run-20260428_171645-wt40fdyx/run-wt40fdyx.wandb filter=lfs diff=lfs merge=lfs -text
40
- wandb/run-20260502_205213-nj3nqkga/run-nj3nqkga.wandb filter=lfs diff=lfs merge=lfs -text
41
- wandb/run-20260504_065935-dzvbyo3j/run-dzvbyo3j.wandb filter=lfs diff=lfs merge=lfs -text
42
- wandb/run-20260505_075313-ti70f47q/run-ti70f47q.wandb filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
38
+ *.ipynb filter=lfs diff=lfs merge=lfs -text
39
+ *.json filter=lfs diff=lfs merge=lfs -text
40
+ media/refined.webp filter=lfs diff=lfs merge=lfs -text
 
 
.gitignore CHANGED
@@ -11,11 +11,3 @@ datasets
11
  test
12
  wandb
13
  nohup.out
14
- samples/
15
- transformer/
16
- *.jpg
17
- *.png
18
- datasets/
19
- samples/
20
- *.jpg
21
- train.py
 
11
  test
12
  wandb
13
  nohup.out
 
 
 
 
 
 
 
 
dataset-sdxs2b-1152.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pip install flash-attn --no-build-isolation
2
+ import torch
3
+ import os
4
+ import gc
5
+ import numpy as np
6
+ import random
7
+ import json
8
+ import shutil
9
+ import time
10
+
11
+ from datasets import Dataset, load_from_disk, concatenate_datasets
12
+ from diffusers import AutoencoderKLQwenImage
13
+ from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
14
+ from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
15
+ from typing import Dict, List, Tuple, Optional, Any
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+ from datetime import timedelta
19
+ from accelerate import Accelerator
20
+
21
# One Accelerator per process; its rank and world size drive the sharding below.
accelerator = Accelerator()
device = accelerator.device
is_main_process = accelerator.is_main_process
process_index = accelerator.process_index
num_processes = accelerator.num_processes

# ---------------- 1. Settings ----------------
dtype = torch.float16
batch_size = 5
min_size = 576      # lower bound for either side of the crop
max_size = 1152     # upper bound for either side of the crop
step = 64           # crop sides are floored to a multiple of this
empty_share = 0.0   # probability passed to process_labels_for_guidance
limit = 0           # forwarded to process_folder

folder_path = "/root/dataset"
save_path = "/root/ds1234_1152_vae_qwen"
os.makedirs(save_path, exist_ok=True)
39
+
40
def clear_cuda_memory():
    """Report peak CUDA memory for this process, then free cached memory.

    When CUDA is available, prints the peak allocated memory (from
    ``torch.cuda.max_memory_allocated``) in GiB and empties the CUDA cache.
    Python garbage collection runs in every case.
    """
    if torch.cuda.is_available():
        used_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"[GPU {process_index}] used_gb: {used_gb:.2f} GB")
        torch.cuda.empty_cache()
    gc.collect()
46
+
47
+ # ---------------- 2️⃣ Загрузка моделей ----------------
48
def load_models():
    """Load the VAE from the local ``"vae"`` path and return it.

    The model is loaded with the module-level ``dtype``, moved to the
    module-level ``device`` and switched to eval mode.
    """
    print(f"[GPU {process_index}] Загрузка моделей...")
    vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
    return vae
52
+
53
vae = load_models()

# Missing or None config values fall back to the identity shift/scale.
shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0) or 1.0

mean = getattr(vae.config, "latents_mean", None)
std = getattr(vae.config, "latents_std", None)

# Always bind both names so later code can reference them even when the
# config carries no per-channel statistics.
latents_mean = None
latents_std = None
if mean is not None and std is not None:
    # Reshaped to (1, C, 1, 1, 1) so they broadcast over 5-D latents.
    latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1, 1)
    latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1, 1)
63
+
64
+ # ---------------- 3️⃣ Трансформации ----------------
65
def get_image_transform(min_size=256, max_size=512, step=64):
    """Build a resize-and-crop transform bounded by ``min_size``/``max_size``.

    The returned ``transform(img, dry_run=False)`` works as follows:

    1. Scale the image so its longer side equals ``max_size``.
    2. If either side is then below ``min_size``, scale so the shorter side
       equals ``min_size`` instead.
    3. Floor each side to a multiple of ``step``, then clamp it to
       ``[min_size, max_size]`` to get the crop size.

    With ``dry_run=True`` only ``img.size`` is read and
    ``(crop_width, crop_height)`` is returned. Otherwise the image is
    converted to RGB, resized with LANCZOS, cropped, and returned as
    ``(img_tensor, img_cropped, final_width, final_height)``, where
    ``img_tensor`` is normalized with mean 0.5 and std 0.5 per channel.
    """
    def transform(img, dry_run=False):
        original_width, original_height = img.size

        # Fit the longer side to max_size, keeping the aspect ratio.
        if original_width >= original_height:
            new_width = max_size
            new_height = int(max_size * original_height / original_width)
        else:
            new_height = max_size
            new_width = int(max_size * original_width / original_height)

        # Too thin after the fit: anchor the shorter side to min_size instead.
        if new_height < min_size or new_width < min_size:
            if original_width <= original_height:
                new_width = min_size
                new_height = int(min_size * original_height / original_width)
            else:
                new_height = min_size
                new_width = int(min_size * original_width / original_height)

        # Snap down to the step grid, then clamp into [min_size, max_size].
        crop_width = min(max_size, (new_width // step) * step)
        crop_height = min(max_size, (new_height // step) * step)

        crop_width = max(min_size, crop_width)
        crop_height = max(min_size, crop_height)

        if dry_run:
            return crop_width, crop_height

        img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)

        # The crop starts at the left edge and one third of the way into the
        # vertical slack.
        top = (new_height - crop_height) // 3
        left = 0

        img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))

        final_width, final_height = img_cropped.size

        img_tensor = ToTensor()(img_cropped)
        img_tensor = Normalize(mean=[0.5] * 3, std=[0.5] * 3)(img_tensor)
        return img_tensor, img_cropped, final_width, final_height

    return transform
107
+
108
+ # ---------------- 4️⃣ Функции обработки ----------------
109
def clean_label(label):
    """Strip boilerplate caption phrases from ``label`` and return the result.

    Removes every occurrence of the "Image N" markers and the "The image ..."
    lead-ins listed below, trims surrounding whitespace, and drops a single
    leading "." together with the whitespace that follows it.
    """
    # Order matters only in that it mirrors the original replacement sequence.
    boilerplate = (
        "Image 1", "Image 2", "Image 3", "Image 4",
        "The image depicts ", "The image presents ",
        "The image features ", "The image portrays ", "The image is ",
    )
    for phrase in boilerplate:
        label = label.replace(phrase, "")
    label = label.strip()
    if label.startswith("."):
        label = label[1:].lstrip()
    return label
116
+
117
def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
    """Randomly blank labels and return ``(labels_for_model, labels_for_logging)``.

    Each label is replaced by ``""`` in the model list with probability
    ``prob_to_make_empty``; the logging list then holds ``"zero: <label>"``
    for that position. Labels that are not blanked appear unchanged in both
    lists.
    """
    labels_for_model = []
    labels_for_logging = []

    for label in original_labels:
        if random.random() < prob_to_make_empty:
            labels_for_model.append("")
            labels_for_logging.append(f"zero: {label}")
        else:
            labels_for_model.append(label)
            labels_for_logging.append(label)

    return labels_for_model, labels_for_logging
130
+
131
def encode_to_latents(images, texts):
    """Encode a batch of images to VAE latents and pair them with captions.

    Each image goes through the module-level resize/crop transform. An image
    whose transform raises is reported and skipped together with its caption,
    so all returned columns have the same length. The surviving tensors are
    stacked, moved to ``device``/``dtype``, given an extra axis at dim 2 when
    the stack is 4-D, and encoded with ``vae.encode(...).latent_dist.mode()``.
    Latents are normalized with ``latents_mean``/``latents_std`` when the VAE
    config provides them, then shifted and scaled by
    ``shift_factor``/``scaling_factor``.

    Returns a dict with keys ``"vae"``, ``"text"``, ``"width"``, ``"height"``,
    or ``None`` when no image could be transformed.
    """
    transform = get_image_transform(min_size, max_size, step)

    transformed_tensors = []
    kept_texts = []
    widths, heights = [], []

    for img, text in zip(images, texts):
        try:
            t_img, _, w, h = transform(img)
        except Exception as e:
            print(f"Ошибка трансформации: {e}")
            continue
        transformed_tensors.append(t_img)
        kept_texts.append(text)  # keep captions aligned with surviving images
        widths.append(w)
        heights.append(h)

    if not transformed_tensors:
        return None

    batch_tensor = torch.stack(transformed_tensors).to(device, dtype)

    # The VAE is fed 5-D input; a 4-D image batch gets a singleton axis at dim 2.
    if batch_tensor.ndim == 4:
        batch_tensor = batch_tensor.unsqueeze(2)

    with torch.no_grad():
        posteriors = vae.encode(batch_tensor).latent_dist.mode()
        if mean is not None and std is not None:
            posteriors = (posteriors - latents_mean) / latents_std
        posteriors = (posteriors - shift_factor) / scaling_factor

    # Drop the singleton axis again before storing.
    latents_np = posteriors.squeeze(2).cpu().numpy()

    text_labels = [clean_label(text) for text in kept_texts]
    # NOTE(review): the *logging* labels (with the "zero: " prefix) are the
    # ones stored, not the blanked model labels — confirm this is intended.
    _, text_labels = process_labels_for_guidance(text_labels, empty_share)

    return {
        "vae": latents_np,
        "text": text_labels,
        "width": widths,
        "height": heights,
    }
172
+
173
+ # ---------------- 5️⃣ Обработка папки ----------------
174
def process_folder(folder_path, limit=None):
    """Collect image/caption pairs under ``folder_path`` with their crop sizes.

    Walks the tree for ``.jpg``/``.jpeg``/``.png``/``.webp`` files that have a
    sibling ``.txt`` file with the same stem. For each pair the crop size is
    computed with a dry run of the module-level transform. Images that cannot
    be opened are reported and skipped.

    ``limit`` is accepted for interface compatibility but is not used.

    Returns ``(image_paths, text_paths, width, height)`` as parallel lists.
    """
    image_paths, text_paths, width, height = [], [], [], []
    transform = get_image_transform(min_size, max_size, step)
    image_exts = (".jpg", ".jpeg", ".png", ".webp")

    for root, _, files in os.walk(folder_path):
        for filename in files:
            if not filename.lower().endswith(image_exts):
                continue

            image_path = os.path.join(root, filename)
            text_path = os.path.splitext(image_path)[0] + ".txt"

            # Check for the caption first; it is cheaper than opening the image.
            if not os.path.exists(text_path):
                continue

            try:
                img = Image.open(image_path)
            except Exception as e:
                # Best-effort scan: an unreadable image is skipped, not fatal.
                print(f"Skipping unreadable image {image_path}: {e}")
                continue

            # Only the size is needed; close the file handle right away.
            with img:
                w, h = transform(img, dry_run=True)

            image_paths.append(image_path)
            text_paths.append(text_path)
            width.append(w)
            height.append(h)

    print(f"Найдено {len(image_paths)} изображений")
    return image_paths, text_paths, width, height
198
+
199
def process_in_chunks(image_paths, text_paths, width, height, chunk_size=5000, batch_size=1):
    """Encode the files chunk by chunk and save one dataset per size group.

    The parallel input lists are cut into slices of ``chunk_size``. In each
    slice the captions are read (an unreadable caption becomes ``""``), items
    are grouped by ``(width, height)``, and every group is mapped through
    ``encode_to_latents`` in batches of ``batch_size``. Each result is saved
    under ``f"{save_path}_temp"`` with the chunk index, size and process index
    in its directory name, so parallel processes do not collide.
    """
    total_files = len(image_paths)

    for chunk_idx, start in enumerate(range(0, total_files, chunk_size), 1):
        end = min(start + chunk_size, total_files)

        chunk_image_paths = image_paths[start:end]
        chunk_text_paths = text_paths[start:end]
        chunk_widths = width[start:end]
        chunk_heights = height[start:end]

        chunk_texts = []
        for text_path in chunk_text_paths:
            try:
                with open(text_path, "r", encoding="utf-8") as f:
                    chunk_texts.append(f.read().strip())
            except (OSError, UnicodeDecodeError):
                # Keep the image and store an empty caption.
                chunk_texts.append("")

        # Only same-size images can be stacked into one tensor.
        size_groups = {}
        for path, text, w, h in zip(chunk_image_paths, chunk_texts, chunk_widths, chunk_heights):
            group = size_groups.setdefault((w, h), {"image_paths": [], "texts": []})
            group["image_paths"].append(path)
            group["texts"].append(text)

        for size_key, group_data in size_groups.items():
            group_dataset = Dataset.from_dict(group_data)

            processed_group = group_dataset.map(
                lambda ex: encode_to_latents(
                    [Image.open(p) for p in ex["image_paths"]],
                    ex["texts"],
                ),
                batched=True,
                batch_size=batch_size,
            )

            # Unique path per chunk, size group and process.
            group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_{size_key[0]}x{size_key[1]}_proc_{process_index}_"

            processed_group.save_to_disk(group_save_path)
            clear_cuda_memory()
245
+
246
+ # ---------------- 7️⃣ Объединение ----------------
247
def combine_chunks(temp_path, final_path):
    """Concatenate every saved chunk in ``temp_path`` and save to ``final_path``.

    Chunk directories are the entries of ``temp_path`` whose name contains
    ``"chunk_"``; they are loaded in sorted path order.

    Raises:
        ValueError: if ``temp_path`` contains no chunk directories.
    """
    chunks = sorted(
        os.path.join(temp_path, d)
        for d in os.listdir(temp_path)
        if "chunk_" in d
    )
    if not chunks:
        # Fail with a clear message rather than inside concatenate_datasets.
        raise ValueError(f"No chunk_* datasets found in {temp_path!r}")

    loaded = [load_from_disk(c) for c in chunks]
    combined = concatenate_datasets(loaded)
    combined.save_to_disk(final_path)

    print("✅ Сохранено")
259
+
260
+ # ---------------- MAIN ----------------
261
temp_path = f"{save_path}_temp"
os.makedirs(temp_path, exist_ok=True)

image_paths, text_paths, width, height = process_folder(folder_path, limit)

# Sort by (width, height) so same-size images end up next to each other.
sorted_indices = sorted(range(len(width)), key=lambda i: (width[i], height[i]))
image_paths = [image_paths[i] for i in sorted_indices]
text_paths = [text_paths[i] for i in sorted_indices]
width = [width[i] for i in sorted_indices]
height = [height[i] for i in sorted_indices]

# --- shard across processes: process k takes items k, k + N, k + 2N, ... ---
image_paths = image_paths[process_index::num_processes]
text_paths = text_paths[process_index::num_processes]
width = width[process_index::num_processes]
height = height[process_index::num_processes]

print(f"[GPU {process_index}] обрабатывает {len(image_paths)} файлов")

process_in_chunks(image_paths, text_paths, width, height, chunk_size=1000, batch_size=batch_size)

# Every process must finish writing its chunks before they are merged.
accelerator.wait_for_everyone()

# --- main process only ---
if is_main_process:
    # NOTE(review): this deletes the *input* dataset folder before the
    # combined dataset has been written. If combine_chunks fails, the source
    # images are already gone. Confirm this is intended (e.g. to free disk
    # space), or move it after combine_chunks.
    try:
        shutil.rmtree(folder_path)
    except OSError as e:
        print(f"Could not remove {folder_path}: {e}")

    combine_chunks(temp_path, save_path)

    try:
        shutil.rmtree(temp_path)
    except OSError as e:
        print(f"Could not remove {temp_path}: {e}")
dataset-sdxs2b-640.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pip install flash-attn --no-build-isolation
2
+ import torch
3
+ import os
4
+ import gc
5
+ import numpy as np
6
+ import random
7
+ import json
8
+ import shutil
9
+ import time
10
+
11
+ from datasets import Dataset, load_from_disk, concatenate_datasets
12
+ from diffusers import AutoencoderKLQwenImage
13
+ from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
14
+ from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
15
+ from typing import Dict, List, Tuple, Optional, Any
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+ from datetime import timedelta
19
+ from accelerate import Accelerator
20
+
21
# One Accelerator per process; its rank and world size drive the sharding below.
accelerator = Accelerator()
device = accelerator.device
is_main_process = accelerator.is_main_process
process_index = accelerator.process_index
num_processes = accelerator.num_processes

# ---------------- 1. Settings ----------------
dtype = torch.float16
batch_size = 5
min_size = 320      # lower bound for either side of the crop
max_size = 640      # upper bound for either side of the crop
step = 64           # crop sides are floored to a multiple of this
empty_share = 0.0   # probability passed to process_labels_for_guidance
limit = 0           # forwarded to process_folder

folder_path = "/root/datasets/butterfly"
save_path = "datasets/dsb_640_vae_qwen"
os.makedirs(save_path, exist_ok=True)
39
+
40
def clear_cuda_memory():
    """Report peak CUDA memory for this process, then free cached memory.

    When CUDA is available, prints the peak allocated memory (from
    ``torch.cuda.max_memory_allocated``) in GiB and empties the CUDA cache.
    Python garbage collection runs in every case.
    """
    if torch.cuda.is_available():
        used_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"[GPU {process_index}] used_gb: {used_gb:.2f} GB")
        torch.cuda.empty_cache()
    gc.collect()
46
+
47
+ # ---------------- 2️⃣ Загрузка моделей ----------------
48
def load_models():
    """Load the VAE from the local ``"vae"`` path and return it.

    The model is loaded with the module-level ``dtype``, moved to the
    module-level ``device`` and switched to eval mode.
    """
    print(f"[GPU {process_index}] Загрузка моделей...")
    vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
    return vae
52
+
53
vae = load_models()

# Missing or None config values fall back to the identity shift/scale.
shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0) or 1.0

mean = getattr(vae.config, "latents_mean", None)
std = getattr(vae.config, "latents_std", None)

# Always bind both names so later code can reference them even when the
# config carries no per-channel statistics.
latents_mean = None
latents_std = None
if mean is not None and std is not None:
    # Reshaped to (1, C, 1, 1, 1) so they broadcast over 5-D latents.
    latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1, 1)
    latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1, 1)
63
+
64
+ # ---------------- 3️⃣ Трансформации ----------------
65
def get_image_transform(min_size=256, max_size=512, step=64):
    """Build a resize-and-crop transform bounded by ``min_size``/``max_size``.

    The returned ``transform(img, dry_run=False)`` works as follows:

    1. Scale the image so its longer side equals ``max_size``.
    2. If either side is then below ``min_size``, scale so the shorter side
       equals ``min_size`` instead.
    3. Floor each side to a multiple of ``step``, then clamp it to
       ``[min_size, max_size]`` to get the crop size.

    With ``dry_run=True`` only ``img.size`` is read and
    ``(crop_width, crop_height)`` is returned. Otherwise the image is
    converted to RGB, resized with LANCZOS, cropped, and returned as
    ``(img_tensor, img_cropped, final_width, final_height)``, where
    ``img_tensor`` is normalized with mean 0.5 and std 0.5 per channel.
    """
    def transform(img, dry_run=False):
        original_width, original_height = img.size

        # Fit the longer side to max_size, keeping the aspect ratio.
        if original_width >= original_height:
            new_width = max_size
            new_height = int(max_size * original_height / original_width)
        else:
            new_height = max_size
            new_width = int(max_size * original_width / original_height)

        # Too thin after the fit: anchor the shorter side to min_size instead.
        if new_height < min_size or new_width < min_size:
            if original_width <= original_height:
                new_width = min_size
                new_height = int(min_size * original_height / original_width)
            else:
                new_height = min_size
                new_width = int(min_size * original_width / original_height)

        # Snap down to the step grid, then clamp into [min_size, max_size].
        crop_width = min(max_size, (new_width // step) * step)
        crop_height = min(max_size, (new_height // step) * step)

        crop_width = max(min_size, crop_width)
        crop_height = max(min_size, crop_height)

        if dry_run:
            return crop_width, crop_height

        img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)

        # The crop starts at the left edge and one third of the way into the
        # vertical slack.
        top = (new_height - crop_height) // 3
        left = 0

        img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))

        final_width, final_height = img_cropped.size

        img_tensor = ToTensor()(img_cropped)
        img_tensor = Normalize(mean=[0.5] * 3, std=[0.5] * 3)(img_tensor)
        return img_tensor, img_cropped, final_width, final_height

    return transform
107
+
108
+ # ---------------- 4️⃣ Функции обработки ----------------
109
def clean_label(label):
    """Strip boilerplate caption phrases from ``label`` and return the result.

    Removes every occurrence of the "Image N" markers and the "The image ..."
    lead-ins listed below, trims surrounding whitespace, and drops a single
    leading "." together with the whitespace that follows it.
    """
    # Order matters only in that it mirrors the original replacement sequence.
    boilerplate = (
        "Image 1", "Image 2", "Image 3", "Image 4",
        "The image depicts ", "The image presents ",
        "The image features ", "The image portrays ", "The image is ",
    )
    for phrase in boilerplate:
        label = label.replace(phrase, "")
    label = label.strip()
    if label.startswith("."):
        label = label[1:].lstrip()
    return label
116
+
117
def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
    """Randomly blank labels and return ``(labels_for_model, labels_for_logging)``.

    Each label is replaced by ``""`` in the model list with probability
    ``prob_to_make_empty``; the logging list then holds ``"zero: <label>"``
    for that position. Labels that are not blanked appear unchanged in both
    lists.
    """
    labels_for_model = []
    labels_for_logging = []

    for label in original_labels:
        if random.random() < prob_to_make_empty:
            labels_for_model.append("")
            labels_for_logging.append(f"zero: {label}")
        else:
            labels_for_model.append(label)
            labels_for_logging.append(label)

    return labels_for_model, labels_for_logging
130
+
131
def encode_to_latents(images, texts):
    """Encode a batch of images to VAE latents and pair them with captions.

    Each image goes through the module-level resize/crop transform. An image
    whose transform raises is reported and skipped together with its caption,
    so all returned columns have the same length. The surviving tensors are
    stacked, moved to ``device``/``dtype``, given an extra axis at dim 2 when
    the stack is 4-D, and encoded with ``vae.encode(...).latent_dist.mode()``.
    Latents are normalized with ``latents_mean``/``latents_std`` when the VAE
    config provides them, then shifted and scaled by
    ``shift_factor``/``scaling_factor``.

    Returns a dict with keys ``"vae"``, ``"text"``, ``"width"``, ``"height"``,
    or ``None`` when no image could be transformed.
    """
    transform = get_image_transform(min_size, max_size, step)

    transformed_tensors = []
    kept_texts = []
    widths, heights = [], []

    for img, text in zip(images, texts):
        try:
            t_img, _, w, h = transform(img)
        except Exception as e:
            print(f"Ошибка трансформации: {e}")
            continue
        transformed_tensors.append(t_img)
        kept_texts.append(text)  # keep captions aligned with surviving images
        widths.append(w)
        heights.append(h)

    if not transformed_tensors:
        return None

    batch_tensor = torch.stack(transformed_tensors).to(device, dtype)

    # The VAE is fed 5-D input; a 4-D image batch gets a singleton axis at dim 2.
    if batch_tensor.ndim == 4:
        batch_tensor = batch_tensor.unsqueeze(2)

    with torch.no_grad():
        posteriors = vae.encode(batch_tensor).latent_dist.mode()
        if mean is not None and std is not None:
            posteriors = (posteriors - latents_mean) / latents_std
        posteriors = (posteriors - shift_factor) / scaling_factor

    # Drop the singleton axis again before storing.
    latents_np = posteriors.squeeze(2).cpu().numpy()

    text_labels = [clean_label(text) for text in kept_texts]
    # NOTE(review): the *logging* labels (with the "zero: " prefix) are the
    # ones stored, not the blanked model labels — confirm this is intended.
    _, text_labels = process_labels_for_guidance(text_labels, empty_share)

    return {
        "vae": latents_np,
        "text": text_labels,
        "width": widths,
        "height": heights,
    }
172
+
173
+ # ---------------- 5️⃣ Обработка папки ----------------
174
def process_folder(folder_path, limit=None):
    """Collect image/caption pairs under ``folder_path`` with their crop sizes.

    Walks the tree for ``.jpg``/``.jpeg``/``.png``/``.webp`` files that have a
    sibling ``.txt`` file with the same stem. For each pair the crop size is
    computed with a dry run of the module-level transform. Images that cannot
    be opened are reported and skipped.

    ``limit`` is accepted for interface compatibility but is not used.

    Returns ``(image_paths, text_paths, width, height)`` as parallel lists.
    """
    image_paths, text_paths, width, height = [], [], [], []
    transform = get_image_transform(min_size, max_size, step)
    image_exts = (".jpg", ".jpeg", ".png", ".webp")

    for root, _, files in os.walk(folder_path):
        for filename in files:
            if not filename.lower().endswith(image_exts):
                continue

            image_path = os.path.join(root, filename)
            text_path = os.path.splitext(image_path)[0] + ".txt"

            # Check for the caption first; it is cheaper than opening the image.
            if not os.path.exists(text_path):
                continue

            try:
                img = Image.open(image_path)
            except Exception as e:
                # Best-effort scan: an unreadable image is skipped, not fatal.
                print(f"Skipping unreadable image {image_path}: {e}")
                continue

            # Only the size is needed; close the file handle right away.
            with img:
                w, h = transform(img, dry_run=True)

            image_paths.append(image_path)
            text_paths.append(text_path)
            width.append(w)
            height.append(h)

    print(f"Найдено {len(image_paths)} изображений")
    return image_paths, text_paths, width, height
198
+
199
def process_in_chunks(image_paths, text_paths, width, height, chunk_size=5000, batch_size=1):
    """Encode the files chunk by chunk and save one dataset per size group.

    The parallel input lists are cut into slices of ``chunk_size``. In each
    slice the captions are read (an unreadable caption becomes ``""``), items
    are grouped by ``(width, height)``, and every group is mapped through
    ``encode_to_latents`` in batches of ``batch_size``. Each result is saved
    under ``f"{save_path}_temp"`` with the chunk index, size and process index
    in its directory name, so parallel processes do not collide.
    """
    total_files = len(image_paths)

    for chunk_idx, start in enumerate(range(0, total_files, chunk_size), 1):
        end = min(start + chunk_size, total_files)

        chunk_image_paths = image_paths[start:end]
        chunk_text_paths = text_paths[start:end]
        chunk_widths = width[start:end]
        chunk_heights = height[start:end]

        chunk_texts = []
        for text_path in chunk_text_paths:
            try:
                with open(text_path, "r", encoding="utf-8") as f:
                    chunk_texts.append(f.read().strip())
            except (OSError, UnicodeDecodeError):
                # Keep the image and store an empty caption.
                chunk_texts.append("")

        # Only same-size images can be stacked into one tensor.
        size_groups = {}
        for path, text, w, h in zip(chunk_image_paths, chunk_texts, chunk_widths, chunk_heights):
            group = size_groups.setdefault((w, h), {"image_paths": [], "texts": []})
            group["image_paths"].append(path)
            group["texts"].append(text)

        for size_key, group_data in size_groups.items():
            group_dataset = Dataset.from_dict(group_data)

            processed_group = group_dataset.map(
                lambda ex: encode_to_latents(
                    [Image.open(p) for p in ex["image_paths"]],
                    ex["texts"],
                ),
                batched=True,
                batch_size=batch_size,
            )

            # Unique path per chunk, size group and process.
            group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_{size_key[0]}x{size_key[1]}_proc_{process_index}_"

            processed_group.save_to_disk(group_save_path)
            clear_cuda_memory()
245
+
246
+ # ---------------- 7️⃣ Объединение ----------------
247
def combine_chunks(temp_path, final_path):
    """Concatenate every saved chunk in ``temp_path`` and save to ``final_path``.

    Chunk directories are the entries of ``temp_path`` whose name contains
    ``"chunk_"``; they are loaded in sorted path order.

    Raises:
        ValueError: if ``temp_path`` contains no chunk directories.
    """
    chunks = sorted(
        os.path.join(temp_path, d)
        for d in os.listdir(temp_path)
        if "chunk_" in d
    )
    if not chunks:
        # Fail with a clear message rather than inside concatenate_datasets.
        raise ValueError(f"No chunk_* datasets found in {temp_path!r}")

    loaded = [load_from_disk(c) for c in chunks]
    combined = concatenate_datasets(loaded)
    combined.save_to_disk(final_path)

    print("✅ Сохранено")
259
+
260
+ # ---------------- MAIN ----------------
261
temp_path = f"{save_path}_temp"
os.makedirs(temp_path, exist_ok=True)

image_paths, text_paths, width, height = process_folder(folder_path, limit)

# Sort by (width, height) so same-size images end up next to each other.
sorted_indices = sorted(range(len(width)), key=lambda i: (width[i], height[i]))
image_paths = [image_paths[i] for i in sorted_indices]
text_paths = [text_paths[i] for i in sorted_indices]
width = [width[i] for i in sorted_indices]
height = [height[i] for i in sorted_indices]

# --- shard across processes: process k takes items k, k + N, k + 2N, ... ---
image_paths = image_paths[process_index::num_processes]
text_paths = text_paths[process_index::num_processes]
width = width[process_index::num_processes]
height = height[process_index::num_processes]

print(f"[GPU {process_index}] обрабатывает {len(image_paths)} файлов")

# NOTE(review): this script never calls combine_chunks; the per-group chunks
# stay under f"{save_path}_temp". Confirm that merging is done elsewhere.
process_in_chunks(image_paths, text_paths, width, height, chunk_size=1000, batch_size=batch_size)
285
+
dataset_sample.ipynb CHANGED
@@ -1,170 +1,3 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 3,
6
- "id": "9c312df2-cb57-44f6-af54-3af6ab8f962f",
7
- "metadata": {},
8
- "outputs": [
9
- {
10
- "ename": "ModuleNotFoundError",
11
- "evalue": "No module named 'numpy'",
12
- "output_type": "error",
13
- "traceback": [
14
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
15
- "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
16
- "Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#from datasets import load_from_disk\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mPIL\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Image\n",
17
- "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'numpy'"
18
- ]
19
- }
20
- ],
21
- "source": [
22
- "from datasets import load_from_disk\n",
23
- "import numpy as np\n",
24
- "import torch\n",
25
- "from PIL import Image\n",
26
- "from collections import defaultdict\n",
27
- "from diffusers import AutoencoderKLQwenImage\n",
28
- "import gc\n",
29
- "\n",
30
- "def analyze_dataset_by_size(dataset_path):\n",
31
- " \"\"\"\n",
32
- " Группирует датасет по размерам изображений и выводит базовую информацию.\n",
33
- " \"\"\"\n",
34
- " # Настройка устройства и типа данных\n",
35
- " dtype = torch.float16\n",
36
- " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
37
- " \n",
38
- " # Загрузка VAE модели\n",
39
- " print(\"Загрузка VAE модели...\")\n",
40
- " vae = AutoencoderKLQwenImage.from_pretrained(\"vae\",torch_dtype=dtype).to(device).eval()\n",
41
- " shift_factor = getattr(vae.config, \"shift_factor\", 0.0)\n",
42
- " if shift_factor is None:\n",
43
- " shift_factor = 0.0\n",
44
- " \n",
45
- " scaling_factor = getattr(vae.config, \"scaling_factor\", 1.0)\n",
46
- " if scaling_factor is None:\n",
47
- " scaling_factor = 1.0\n",
48
- " \n",
49
- " mean = getattr(vae.config, \"latents_mean\", None)\n",
50
- " std = getattr(vae.config, \"latents_std\", None)\n",
51
- " if mean is not None and std is not None:\n",
52
- " latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)\n",
53
- " latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)\n",
54
- " \n",
55
- " # Загружаем датасет\n",
56
- " print(f\"Загрузка датасета из {dataset_path}...\")\n",
57
- " dataset = load_from_disk(dataset_path)\n",
58
- "\n",
59
- " print(f\"Осталось примеров после фильтрации: {len(dataset)}\")\n",
60
- " \n",
61
- " # Группируем примеры по размерам\n",
62
- " print(\"\\nГруппировка примеров по размерам...\")\n",
63
- " size_to_indices = defaultdict(list)\n",
64
- " \n",
65
- " # Собираем примеры с одинаковыми размерами\n",
66
- " # Собираем примеры с одинаковыми размерами (оптимизированная версия)\n",
67
- " widths = dataset[\"width\"]\n",
68
- " heights = dataset[\"height\"]\n",
69
- " for i, (w, h) in enumerate(zip(widths, heights)):\n",
70
- " size_to_indices[(w, h)].append(i)\n",
71
- " \n",
72
- " # Сортируем размеры по количеству примеров\n",
73
- " print(\"\\nСортируем...\")\n",
74
- " size_stats = [(size, len(indices)) for size, indices in size_to_indices.items()]\n",
75
- " size_stats.sort(key=lambda x: x[1], reverse=True)\n",
76
- " \n",
77
- " # Выводим информацию о каждой группе и показываем первый пример\n",
78
- " for size, count in size_stats:\n",
79
- " width, height = size\n",
80
- " first_idx = size_to_indices[size][1]\n",
81
- " example = dataset[first_idx]\n",
82
- " \n",
83
- " print(f\"\\n--- Батч {width}x{height}: {count} примеров ---\")\n",
84
- " \n",
85
- " # Декодируем латентное представление для первого примера\n",
86
- " latent = torch.tensor(example[\"vae\"], dtype=dtype).unsqueeze(0).to(device)\n",
87
- " \n",
88
- " # 1. Снова обманываем VAE, превращая картинку в \"видео из 1 кадра\" [B, C, 1, H, W]\n",
89
- " if latent.ndim == 4:\n",
90
- " latent = latent.unsqueeze(2)\n",
91
- " \n",
92
- " with torch.no_grad():\n",
93
- " if latents_mean is not None and latents_std is not None:\n",
94
- " latent = latent * latents_std + latents_mean\n",
95
- " \n",
96
- " print(f\"Min of latent_for_vae: {latent.min()}\")\n",
97
- " print(f\"Max of latent_for_vae: {latent.max()}\")\n",
98
- " print(f\"Mean of latent_for_vae: {latent.mean()}\")\n",
99
- " print(f\"Std: {latent.std().item():.4f}\")\n",
100
- " if torch.isnan(latent).any() or torch.isinf(latent).any():\n",
101
- " print(\"WARNING: Raw latents contain NaN or Inf values!\")\n",
102
- " \n",
103
- " reconstructed_image = vae.decode(latent).sample\n",
104
- " \n",
105
- " # 2. Вытаскиваем обычную 3D-картинку [C, H, W] из 5D-видеотензора\n",
106
- " if reconstructed_image.ndim == 5:\n",
107
- " # Берем нулевой батч, все каналы, нулевой кадр, всю высоту и ширину\n",
108
- " img_tensor = reconstructed_image[0, :, 0, :, :] \n",
109
- " else:\n",
110
- " img_tensor = reconstructed_image.squeeze(0) # На всякий случай, если VAE вернул 4D\n",
111
- " \n",
112
- " img_array = img_tensor.cpu().numpy()\n",
113
- " img_array = np.transpose(img_array, (1, 2, 0))\n",
114
- " img_array = (img_array + 1) / 2 # Нормализация к [0, 1]\n",
115
- " img_array = np.clip(img_array * 255, 0, 255).astype(np.uint8) # Преобразуем в uint8 для PIL\n",
116
- " \n",
117
- " # Создаем PIL изображение из массива\n",
118
- " pil_image = Image.fromarray(img_array)\n",
119
- " print(f\"Текст: {example['text']}\")\n",
120
- " print(f\"Ключи: {', '.join(example.keys())}\")\n",
121
- " print(f\"latent: {latent.shape}\")\n",
122
- " pil_image.save(\"1.jpg\")\n",
123
- " \n",
124
- " # Очистка памяти\n",
125
- " if torch.cuda.is_available():\n",
126
- " torch.cuda.empty_cache()\n",
127
- " gc.collect()\n",
128
- " \n",
129
- " return size_to_indices # Возвращаем словарь с индексами по группам\n",
130
- "\n",
131
- "# Использование\n",
132
- "if __name__ == \"__main__\":\n",
133
- " # Путь к датасету\n",
134
- " save_path = \"datasets/ds234_640_vae_qwen\"\n",
135
- " \n",
136
- " # Анализ датасета\n",
137
- " size_groups = analyze_dataset_by_size(save_path)"
138
- ]
139
- },
140
- {
141
- "cell_type": "code",
142
- "execution_count": null,
143
- "id": "74a5d11d-369f-4f25-9ee0-31d3bccd0254",
144
- "metadata": {},
145
- "outputs": [],
146
- "source": []
147
- }
148
- ],
149
- "metadata": {
150
- "kernelspec": {
151
- "display_name": "Python 3 (ipykernel)",
152
- "language": "python",
153
- "name": "python3"
154
- },
155
- "language_info": {
156
- "codemirror_mode": {
157
- "name": "ipython",
158
- "version": 3
159
- },
160
- "file_extension": ".py",
161
- "mimetype": "text/x-python",
162
- "name": "python",
163
- "nbconvert_exporter": "python",
164
- "pygments_lexer": "ipython3",
165
- "version": "3.12.3"
166
- }
167
- },
168
- "nbformat": 4,
169
- "nbformat_minor": 5
170
- }
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:774dc5b6f2f55e8b4e925e5ba984f73b18e2c096b6c1df4bfe0075aa51a56258
3
+ size 8190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pipeline_sdxs.py CHANGED
@@ -14,12 +14,10 @@ class SdxsPipelineOutput(BaseOutput):
14
  prompt: Optional[Union[str, List[str]]] = None
15
 
16
  class SdxsPipeline(DiffusionPipeline):
17
- # Cosmos требует 512 токенов
18
- MAX_TEXT_TOKENS = 512
19
 
20
  def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):
21
  super().__init__()
22
- # Регистрируем модули (с Qwen)
23
  self.register_modules(
24
  vae=vae,
25
  text_encoder=text_encoder,
@@ -28,62 +26,36 @@ class SdxsPipeline(DiffusionPipeline):
28
  scheduler=scheduler
29
  )
30
 
31
- self.vae_scale_factor = getattr(self.vae.config, "spatial_compression_ratio", 8)
32
- if hasattr(self.vae.config, "block_out_channels"):
33
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
34
 
35
- # Загружаем mean и std для VAE (Cosmos-style)
36
- mean = getattr(self.vae.config, "latents_mean", None)
37
- std = getattr(self.vae.config, "latents_std", None)
38
- if mean is not None and std is not None:
39
- self.vae_latents_mean = torch.tensor(mean).view(1, len(mean), 1, 1, 1)
40
- # Внимание: Cosmos использует инвертированный std для декодирования (1.0 / std)
41
- self.vae_latents_std = torch.tensor(std).view(1, len(std), 1, 1, 1)
42
- else:
43
- self.vae_latents_mean = None
44
- self.vae_latents_std = None
45
-
46
- # Регистрируем параметры Cosmos в шедулере (если они еще не там)
47
- if self.scheduler is not None:
48
- self.scheduler.register_to_config(
49
- sigma_max=getattr(self.scheduler.config, "sigma_max", 80.0),
50
- sigma_min=getattr(self.scheduler.config, "sigma_min", 0.002),
51
- sigma_data=getattr(self.scheduler.config, "sigma_data", 1.0),
52
- final_sigmas_type=getattr(self.scheduler.config, "final_sigmas_type", "sigma_min"),
53
- )
54
-
55
- @staticmethod
56
- def _pad_tensor_to_length(tensor: torch.Tensor, target_len: int, dim: int = 1, pad_value: float = 0) -> torch.Tensor:
57
- current_len = tensor.shape[dim]
58
- if current_len >= target_len:
59
- return tensor
60
- pad_size = target_len - current_len
61
- if tensor.dim() == 3:
62
- padding = (0, 0, 0, pad_size, 0, 0)
63
- elif tensor.dim() == 2:
64
- padding = (0, pad_size, 0, 0)
65
- else:
66
- raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")
67
- return torch.nn.functional.pad(tensor, padding, value=pad_value)
68
-
69
- @torch.no_grad()
70
  def refine_prompts(
71
  self,
72
  prompts: Union[str, List[str]],
73
  system_prompt: Optional[str] = None,
74
  temperature: float = 0.7
75
  ) -> List[str]:
76
- """Refines a list of prompts using the Text Encoder (LLM)."""
 
 
 
 
 
 
 
 
 
 
77
  device = self.device
78
 
 
79
  if system_prompt is None:
80
  system_prompt = (
81
  "You are a skilled text-to-image prompt engineer whose sole function is to transform "
82
- "the user's input into an aesthetically optimized, detailed, and visually descriptive two-sentence output. "
83
- "**The primary subject MUST be the main focus of the revised prompt "
84
- "and MUST be described in rich detail within the first sentence.** "
85
  "Output **only** the final revised prompt, with absolutely no commentary. "
86
- "Don't use cliches like warm, soft, vibrant, wildflowers. Be creative. User input prompt: "
87
  )
88
 
89
  pad_id = getattr(self.text_encoder.config, "pad_token_id", None) or \
@@ -93,6 +65,7 @@ class SdxsPipeline(DiffusionPipeline):
93
  refined_list = []
94
 
95
  for p in prompts_list:
 
96
  full_text = system_prompt + p
97
  messages = [{"role": "user", "content": [{"type": "text", "text": full_text}]}]
98
 
@@ -120,7 +93,6 @@ class SdxsPipeline(DiffusionPipeline):
120
 
121
  @torch.no_grad()
122
  def encode_text(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
123
- """Qwen-specific text encoding (using chat_template and hidden_states[-2])"""
124
  device = self.device
125
  dtype = self.transformer.dtype
126
  if text is None: text = ""
@@ -128,221 +100,148 @@ class SdxsPipeline(DiffusionPipeline):
128
 
129
  formatted_prompts = []
130
  for t in text:
 
131
  messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
132
  formatted_prompts.append(self.tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False))
133
 
134
- toks = self.tokenizer(formatted_prompts, padding="max_length", max_length=self.MAX_TEXT_TOKENS, truncation=True, return_tensors="pt").to(device)
135
- outputs = self.text_encoder(input_ids=toks.input_ids, attention_mask=toks.attention_mask, output_hidden_states=True)
 
 
 
 
 
136
 
137
- # Берем предпоследний слой эмбеддингов, как того требуют современные пайплайны
138
- last_hidden = outputs.hidden_states[-2]
139
-
140
- return last_hidden.to(dtype=dtype), toks.attention_mask.to(dtype=torch.int64)
141
-
142
- @torch.no_grad()
143
- def image_upscale(self, image: Union[str, Image.Image, List[Union[str, Image.Image]]], batch_size: int = 1) -> List[Image.Image]:
144
- images = [image] if isinstance(image, (str, Image.Image)) else image
145
 
146
- batch_data = []
147
- for img in images:
148
- if isinstance(img, str): img = Image.open(img)
149
- if img.mode == "RGBA":
150
- img = Image.alpha_composite(Image.new("RGBA", img.size, (255, 255, 255)), img)
151
- img = img.convert("RGB")
152
-
153
- w, h = img.size
154
- pw, ph = (8 - w % 8) % 8, (8 - h % 8) % 8
155
- if pw or ph:
156
- padded = Image.new("RGB", (w + pw, h + ph), (255, 255, 255))
157
- padded.paste(img)
158
- img = padded
159
-
160
- t = torch.from_numpy(np.array(img).astype(np.float32) / 127.5 - 1.0).permute(2, 0, 1)
161
- batch_data.append((t.to(self.device, torch.float16), w, h))
162
-
163
- unique_shapes = {t.shape for t, _, _ in batch_data}
164
- step = batch_size if len(unique_shapes) == 1 else 1
165
 
166
- output_images = []
167
- for i in range(0, len(batch_data), step):
168
- chunk = batch_data[i : i + step]
169
- tensors = torch.stack([c[0] for c in chunk]).unsqueeze(2)
170
-
171
- latents = self.vae.encode(tensors).latent_dist.mean
172
- decoded = self.vae.decode(latents.to(self.vae.dtype))[0]
173
-
174
- if decoded.ndim == 5:
175
- decoded = decoded.squeeze(2)
176
-
177
- decoded = (decoded.clamp(-1, 1) + 1) / 2
178
- for j, tensor in enumerate(decoded):
179
- w, h = chunk[j][1], chunk[j][2]
180
- arr = tensor.cpu().permute(1, 2, 0).float().numpy()
181
- arr = arr[:h * 2, :w * 2]
182
- output_images.append(Image.fromarray((arr * 255).astype("uint8")))
183
-
184
- return output_images
185
-
186
  @torch.no_grad()
187
  def __call__(
188
  self,
189
  prompt: Optional[Union[str, List[str]]] = None,
190
  negative_prompt: Optional[Union[str, List[str]]] = None,
191
- prompt_embeds: Optional[torch.Tensor] = None,
192
- negative_prompt_embeds: Optional[torch.Tensor] = None,
193
- prompt_attention_mask: Optional[torch.Tensor] = None,
194
- negative_prompt_attention_mask: Optional[torch.Tensor] = None,
195
- latents: Optional[torch.Tensor] = None,
196
- height: int = 1024,
197
- width: int = 1024,
198
  num_inference_steps: int = 40,
199
  guidance_scale: float = 4.0,
200
- generator: Optional[torch.Generator] = None,
201
  seed: Optional[int] = None,
202
  output_type: str = "pil",
203
  return_dict: bool = True,
204
- **kwargs,
205
  ):
206
  device = self.device
207
  dtype = self.transformer.dtype
208
-
209
- if generator is None and seed is not None:
210
  generator = torch.Generator(device=device).manual_seed(seed)
211
-
 
 
212
  do_classifier_free_guidance = guidance_scale > 1.0
213
 
214
- # 1. Encode Positive
215
- if prompt_embeds is None:
216
- if prompt is None: raise ValueError("`prompt` or `prompt_embeds` required.")
217
- prompt_embeds, prompt_attention_mask = self.encode_text(prompt)
218
- prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
219
- prompt_attention_mask = prompt_attention_mask.to(device=device, dtype=torch.int64)
220
  batch_size = prompt_embeds.shape[0]
221
 
222
- # 2. Encode Negative
223
  if do_classifier_free_guidance:
224
- if negative_prompt_embeds is None:
225
- neg_text = negative_prompt if negative_prompt is not None else ("" if isinstance(prompt, str) else [""] * len(prompt))
226
- negative_prompt_embeds, negative_prompt_attention_mask = self.encode_text(neg_text)
227
 
228
- negative_prompt_embeds = negative_prompt_embeds.to(device=device, dtype=dtype)
229
- negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device, dtype=torch.int64)
230
-
231
- if negative_prompt_embeds.shape[0] != batch_size:
232
- negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1)
233
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(batch_size, 1)
234
-
235
- max_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
236
- prompt_embeds = self._pad_tensor_to_length(prompt_embeds, max_len, dim=1, pad_value=0)
237
- negative_prompt_embeds = self._pad_tensor_to_length(negative_prompt_embeds, max_len, dim=1, pad_value=0)
238
- prompt_attention_mask = self._pad_tensor_to_length(prompt_attention_mask, max_len, dim=1, pad_value=0)
239
- negative_prompt_attention_mask = self._pad_tensor_to_length(negative_prompt_attention_mask, max_len, dim=1, pad_value=0)
240
-
241
- text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
242
  else:
243
  text_embeddings = prompt_embeds
244
-
245
- # 3. Prepare Timesteps (Cosmos specific schedule)
246
- sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
247
- sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
248
- self.scheduler.set_timesteps(sigmas=sigmas, device=device)
249
- timesteps = self.scheduler.timesteps
250
-
251
- # Защита от деления на ноль на последнем шаге
252
- if self.scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
253
- self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
254
- if self.scheduler.sigmas[-1] == 0.0:
255
- self.scheduler.sigmas[-1] = 1e-4
256
 
257
- # 4. Prepare Latents (Noise)
 
 
 
 
258
  latent_h = height // self.vae_scale_factor
259
  latent_w = width // self.vae_scale_factor
260
  in_channels = self.transformer.config.in_channels
261
- sigma_max = getattr(self.scheduler.config, "sigma_max", 80.0)
262
 
263
- if latents is None:
264
- # Создаем 5D тензор [Batch, Channels, Frames, Height, Width]
265
- latents = torch.randn((batch_size, in_channels, 1, latent_h, latent_w), generator=generator, device=device, dtype=dtype)
266
- latents = latents * sigma_max
267
- else:
268
- latents = latents.to(device=device, dtype=dtype) * sigma_max
 
269
 
270
- # Cosmos Padding Mask
271
- padding_mask = torch.zeros((1, 1, height, width), device=device, dtype=dtype)
272
 
273
- # 5. Denoising Loop (Continuous Flow Math)
274
- for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
275
- current_sigma = self.scheduler.sigmas[i]
276
-
277
- # Защита от деления на 0 при вычислении current_t
278
- if current_sigma == 0.0:
279
- current_sigma = torch.tensor(1e-4, dtype=current_sigma.dtype, device=device)
280
-
281
- current_t = current_sigma / (current_sigma + 1.0)
282
- c_in = 1.0 - current_t
283
- c_skip = 1.0 - current_t
284
- c_out = -current_t
285
 
 
286
  latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
287
- latent_model_input = (latent_model_input * c_in).to(dtype)
288
 
289
- # Трансформер ждет timestep в виде 1D тензора [B]
290
- t_val = float(current_t.item()) if torch.is_tensor(current_t) else float(current_t)
291
- timestep_tensor = torch.tensor(
292
- [t_val],
293
- device=device,
294
- dtype=dtype
295
- ).view(1, 1, 1, 1, 1).expand(latent_model_input.shape[0], 1, 1, 1, 1)
296
 
297
- model_out = self.transformer(
 
 
298
  hidden_states=latent_model_input,
299
- timestep=timestep_tensor,
300
  encoder_hidden_states=text_embeddings,
301
  padding_mask=padding_mask,
302
  return_dict=False,
303
  )[0]
304
-
305
- batched_latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
306
- noise_pred = (c_skip * batched_latents + c_out * model_out.float()).to(dtype)
307
-
308
  if do_classifier_free_guidance:
309
- noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
310
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
311
-
312
- noise_pred = (latents - noise_pred) / current_sigma
313
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
314
-
315
- # 6. Decode
 
 
 
 
316
  if output_type == "latent":
317
- if not return_dict: return (latents, prompt)
318
  return SdxsPipelineOutput(images=latents)
319
-
320
- if getattr(self.vae.config, "latents_std", None) is not None and getattr(self.vae.config, "latents_mean", None) is not None:
321
- sigma_data = getattr(self.scheduler.config, "sigma_data", 1.0)
322
-
323
  l_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
324
  l_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
325
-
326
- # Оригинальная формула: делим на инвертированный std (что равноценно умножению на std)
327
- #latents_std_inv = 1.0 / l_std
328
  latents = latents * l_std + l_mean
329
 
330
- image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
 
 
331
 
332
  if image_output.ndim == 5:
333
- image_output = image_output.squeeze(2)
334
-
335
  image_output = (image_output.clamp(-1, 1) + 1) / 2
336
  image_np = image_output.cpu().permute(0, 2, 3, 1).float().numpy()
337
-
338
- # На всякий случай вычищаем NaNs
339
- image_np = np.nan_to_num(image_np, nan=0.0, posinf=1.0, neginf=0.0)
340
-
341
  if output_type == "pil":
342
- images = [(Image.fromarray((img * 255).round().astype("uint8"))) for img in image_np]
343
  else:
344
  images = image_np
345
-
346
- if not return_dict:
347
- return (images,)
348
- return SdxsPipelineOutput(images=images)
 
14
  prompt: Optional[Union[str, List[str]]] = None
15
 
16
  class SdxsPipeline(DiffusionPipeline):
17
+ MAX_TEXT_TOKENS = 400 # не Соответствует max_length в обучении
 
18
 
19
  def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):
20
  super().__init__()
 
21
  self.register_modules(
22
  vae=vae,
23
  text_encoder=text_encoder,
 
26
  scheduler=scheduler
27
  )
28
 
29
+ self.vae_scale_factor = 8
 
 
30
 
31
+
32
+ @torch.no_grad()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def refine_prompts(
34
  self,
35
  prompts: Union[str, List[str]],
36
  system_prompt: Optional[str] = None,
37
  temperature: float = 0.7
38
  ) -> List[str]:
39
+ """
40
+ Refines a list of prompts using the Text Encoder (LLM).
41
+
42
+ Args:
43
+ prompts: Single prompt string or list of prompts.
44
+ system_prompt: Custom instruction for the LLM. If None, uses default aesthetic enhancer.
45
+ temperature: Sampling temperature for generation.
46
+
47
+ Returns:
48
+ List of refined prompts.
49
+ """
50
  device = self.device
51
 
52
+ # Default system prompt if none provided
53
  if system_prompt is None:
54
  system_prompt = (
55
  "You are a skilled text-to-image prompt engineer whose sole function is to transform "
56
+ "the user's input into an aesthetic, detailed, and visually descriptive three-sentence output. "
 
 
57
  "Output **only** the final revised prompt, with absolutely no commentary. "
58
+ "Don't use cliches like warm, soft, vibrant, wildflowers. User input prompt: "
59
  )
60
 
61
  pad_id = getattr(self.text_encoder.config, "pad_token_id", None) or \
 
65
  refined_list = []
66
 
67
  for p in prompts_list:
68
+ # Prepend system prompt to user input
69
  full_text = system_prompt + p
70
  messages = [{"role": "user", "content": [{"type": "text", "text": full_text}]}]
71
 
 
93
 
94
  @torch.no_grad()
95
  def encode_text(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
 
96
  device = self.device
97
  dtype = self.transformer.dtype
98
  if text is None: text = ""
 
100
 
101
  formatted_prompts = []
102
  for t in text:
103
+ # Повторяем логику чат-шаблона из обучения
104
  messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
105
  formatted_prompts.append(self.tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False))
106
 
107
+ toks = self.tokenizer(
108
+ formatted_prompts,
109
+ padding="max_length",
110
+ max_length=self.MAX_TEXT_TOKENS,
111
+ truncation=True,
112
+ return_tensors="pt"
113
+ ).to(device)
114
 
115
+ outputs = self.text_encoder(
116
+ input_ids=toks.input_ids,
117
+ attention_mask=toks.attention_mask,
118
+ output_hidden_states=True
119
+ )
 
 
 
120
 
121
+ # Берем предпоследний слой (-2) как в обучении
122
+ last_hidden = outputs.hidden_states[-2].to(dtype=dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ # Обнуляем паддинги для честности (как в обучении)
125
+ lengths = toks.attention_mask.sum(dim=1)
126
+ for i, length in enumerate(lengths):
127
+ last_hidden[i, length:] = 0
128
+
129
+ return last_hidden, toks.attention_mask.to(dtype=torch.int64)
130
+
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  @torch.no_grad()
132
  def __call__(
133
  self,
134
  prompt: Optional[Union[str, List[str]]] = None,
135
  negative_prompt: Optional[Union[str, List[str]]] = None,
136
+ height: int = 1152,
137
+ width: int = 768,
 
 
 
 
 
138
  num_inference_steps: int = 40,
139
  guidance_scale: float = 4.0,
 
140
  seed: Optional[int] = None,
141
  output_type: str = "pil",
142
  return_dict: bool = True,
 
143
  ):
144
  device = self.device
145
  dtype = self.transformer.dtype
146
+
147
+ if seed is not None:
148
  generator = torch.Generator(device=device).manual_seed(seed)
149
+ else:
150
+ generator = None
151
+
152
  do_classifier_free_guidance = guidance_scale > 1.0
153
 
154
+ # 1. Encode Prompts
155
+ prompt_embeds, prompt_mask = self.encode_text(prompt)
 
 
 
 
156
  batch_size = prompt_embeds.shape[0]
157
 
 
158
  if do_classifier_free_guidance:
159
+ neg_text = negative_prompt if negative_prompt is not None else ([""] * batch_size)
160
+ neg_embeds, neg_mask = self.encode_text(neg_text)
 
161
 
162
+ # Конкатенация для батч-генерации (uncond + cond)
163
+ text_embeddings = torch.cat([neg_embeds, prompt_embeds], dim=0)
164
+ # В вашем обучении padding_mask в модель передавался как нули,
165
+ # но внутри трансформера обычно используется encoder_attention_mask
 
 
 
 
 
 
 
 
 
 
166
  else:
167
  text_embeddings = prompt_embeds
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
+ # 2. Prepare Timesteps (Flow Matching: от 1.0 к 0.0)
170
+ # В обучении t=1 был шумом, t=0 — данными.
171
+ timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device, dtype=dtype)
172
+
173
+ # 3. Prepare Latents
174
  latent_h = height // self.vae_scale_factor
175
  latent_w = width // self.vae_scale_factor
176
  in_channels = self.transformer.config.in_channels
 
177
 
178
+ # В Flow Matching начальный шум имеет стандартное отклонение 1.0
179
+ latents = torch.randn(
180
+ (batch_size, in_channels, 1, latent_h, latent_w),
181
+ generator=generator,
182
+ device=device,
183
+ dtype=dtype
184
+ )
185
 
186
+ # Пустая маска как в обучении
187
+ padding_mask = torch.zeros((1, 1, latent_h, latent_w), device=device, dtype=dtype)
188
 
189
+ # 4. Denoising Loop (Euler Method)
190
+ for i in tqdm(range(num_inference_steps), desc="Sampling"):
191
+ t_curr = timesteps[i]
192
+ t_next = timesteps[i+1]
 
 
 
 
 
 
 
 
193
 
194
+ # Подготовка входа (CFG)
195
  latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
196
 
197
+ # Модель обучалась на t.flatten(), передаем как вектор [B]
198
+ t_vec = t_curr.expand(latent_model_input.shape[0])
 
 
 
 
 
199
 
200
+ # Предсказание "скорости" (v)
201
+ # Т.к. в обучении target = noise - clean, модель предсказывает направление к шуму
202
+ model_output = self.transformer(
203
  hidden_states=latent_model_input,
204
+ timestep=t_vec,
205
  encoder_hidden_states=text_embeddings,
206
  padding_mask=padding_mask,
207
  return_dict=False,
208
  )[0]
209
+
 
 
 
210
  if do_classifier_free_guidance:
211
+ v_uncond, v_cond = model_output.chunk(2)
212
+ v_final = v_uncond + guidance_scale * (v_cond - v_uncond)
213
+ else:
214
+ v_final = model_output
215
+
216
+ # Euler шаг: x_{t-1} = x_t + (t_next - t_curr) * v
217
+ # Поскольку t идет от 1 к 0, (t_next - t_curr) будет отрицательным,
218
+ # что правильно двигает нас от шума к данным.
219
+ latents = latents + (t_next - t_curr) * v_final
220
+
221
+ # 5. Decode
222
  if output_type == "latent":
 
223
  return SdxsPipelineOutput(images=latents)
224
+
225
+ # Применяем де-нормализацию VAE как в обучении
226
+ if getattr(self.vae.config, "latents_std", None) is not None:
 
227
  l_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
228
  l_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
 
 
 
229
  latents = latents * l_std + l_mean
230
 
231
+ # Декодируем
232
+ with torch.no_grad():
233
+ image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
234
 
235
  if image_output.ndim == 5:
236
+ image_output = image_output.squeeze(2) # Убираем временную ось (Frames=1)
237
+
238
  image_output = (image_output.clamp(-1, 1) + 1) / 2
239
  image_np = image_output.cpu().permute(0, 2, 3, 1).float().numpy()
240
+ image_np = np.nan_to_num(image_np, nan=0.0)
241
+
 
 
242
  if output_type == "pil":
243
+ images = [Image.fromarray((img * 255).round().astype("uint8")) for img in image_np]
244
  else:
245
  images = image_np
246
+
247
+ return SdxsPipelineOutput(images=images, prompt=prompt)
 
 
refined.jpg ADDED

Git LFS Details

  • SHA256: b08900de198c3d22e7e5dea378269caf74a681d050e677a8c4c299f35fd1f34f
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB
scheduler/.ipynb_checkpoints/scheduler_config-checkpoint.json DELETED
@@ -1,22 +0,0 @@
1
- {
2
- "_class_name": "FlowMatchEulerDiscreteScheduler",
3
- "_diffusers_version": "0.34.0.dev0",
4
- "base_image_seq_len": 256,
5
- "base_shift": 0.5,
6
- "final_sigmas_type": "sigma_min",
7
- "invert_sigmas": false,
8
- "max_image_seq_len": 4096,
9
- "max_shift": 1.15,
10
- "num_train_timesteps": 1000,
11
- "shift": 1.0,
12
- "shift_terminal": null,
13
- "sigma_data": 1.0,
14
- "sigma_max": 80.0,
15
- "sigma_min": 0.002,
16
- "stochastic_sampling": false,
17
- "time_shift_type": "exponential",
18
- "use_beta_sigmas": false,
19
- "use_dynamic_shifting": false,
20
- "use_exponential_sigmas": false,
21
- "use_karras_sigmas": true
22
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scheduler/scheduler_config.json CHANGED
@@ -1,22 +1,3 @@
1
- {
2
- "_class_name": "FlowMatchEulerDiscreteScheduler",
3
- "_diffusers_version": "0.34.0.dev0",
4
- "base_image_seq_len": 256,
5
- "base_shift": 0.5,
6
- "final_sigmas_type": "sigma_min",
7
- "invert_sigmas": false,
8
- "max_image_seq_len": 4096,
9
- "max_shift": 1.15,
10
- "num_train_timesteps": 1000,
11
- "shift": 1.0,
12
- "shift_terminal": null,
13
- "sigma_data": 1.0,
14
- "sigma_max": 80.0,
15
- "sigma_min": 0.002,
16
- "stochastic_sampling": false,
17
- "time_shift_type": "exponential",
18
- "use_beta_sigmas": false,
19
- "use_dynamic_shifting": false,
20
- "use_exponential_sigmas": false,
21
- "use_karras_sigmas": true
22
- }
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65b3e9ccde6e3727aab1c612e7279599f861aec2fb9354880ab9ef8753c492b6
3
+ size 485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:677906d20fb691440965fb107de2c9d8e9b7c75884d9e3e15b4375f4257df8ae
3
- size 21416092
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:feff1f3730b8dae616e3ffd24b2f74dcd9c6776c46e00ac72018e0de74785d06
3
+ size 18136603
train-sdxs2b.py CHANGED
@@ -17,7 +17,7 @@ from torch.utils.data import DataLoader, Sampler
17
  from torch.optim.lr_scheduler import LambdaLR
18
  from collections import defaultdict
19
  from accelerate import Accelerator
20
- from datasets import load_from_disk
21
  from tqdm import tqdm
22
  from PIL import Image, ImageOps
23
  from torch.utils.checkpoint import checkpoint
@@ -33,7 +33,7 @@ os.environ["NCCL_IB_DISABLE"] = "1" # comment this on H100!
33
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
34
 
35
  # --------------------------- Параметры ---------------------------
36
- ds_path = "/root/ds12345_640_vae_qwen"
37
  project = "transformer"
38
 
39
  gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
@@ -81,8 +81,8 @@ torch.backends.cuda.matmul.allow_tf32 = True
81
  torch.backends.cudnn.allow_tf32 = True
82
  torch.backends.cuda.enable_flash_sdp(True)
83
  torch.backends.cuda.enable_mem_efficient_sdp(True)
84
- torch.backends.cuda.enable_math_sdp(False)
85
- save_barrier = 1.25
86
  warmup_percent = 0.0025
87
  betta2 = 0.997
88
  eps = 1e-6
@@ -223,7 +223,7 @@ def encode_texts(text, max_length=max_length):
223
  for i, length in enumerate(lengths):
224
  hidden[i, length:] = 0
225
 
226
- return hidden, toks.attention_mask.to(dtype=torch.int64)
227
 
228
  @torch.no_grad()
229
  def encode_texts_fast(text, max_length=max_length):
@@ -244,7 +244,7 @@ def encode_texts_fast(text, max_length=max_length):
244
  for i, length in enumerate(lengths):
245
  last_hidden[i, length:] = 0
246
 
247
- return last_hidden, toks.attention_mask.to(dtype=torch.int64)
248
 
249
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
250
  if shift_factor is None:
@@ -375,7 +375,7 @@ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
375
  masks = torch.tensor(
376
  np.array([item["attention_mask"] for item in samples_data]),
377
  device=device,
378
- dtype=torch.int64
379
  )
380
  else:
381
  embeddings, masks = encode_texts(texts,max_length)
@@ -388,7 +388,30 @@ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
388
  if limit > 0:
389
  dataset = load_from_disk(ds_path).select(range(limit))
390
  else:
391
- dataset = load_from_disk(ds_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
 
393
  print(f"images: {len(dataset)}")
394
 
@@ -424,7 +447,7 @@ def collate_fn_simple(batch):
424
  ]
425
 
426
  embeddings, attention_mask = encode_texts(texts,max_length)
427
- attention_mask = attention_mask.to(dtype=torch.int64)
428
 
429
  return latents, embeddings, attention_mask
430
 
@@ -552,7 +575,7 @@ def get_negative_embedding(neg_prompt="", batch_size=1):
552
  hidden_dim = 2048
553
  seq_len = max_length
554
  empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
555
- empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
556
  return empty_emb, empty_mask
557
 
558
  uncond_emb, uncond_mask = encode_texts([neg_prompt],max_length)
 
17
  from torch.optim.lr_scheduler import LambdaLR
18
  from collections import defaultdict
19
  from accelerate import Accelerator
20
+ from datasets import load_from_disk,concatenate_datasets
21
  from tqdm import tqdm
22
  from PIL import Image, ImageOps
23
  from torch.utils.checkpoint import checkpoint
 
33
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
34
 
35
  # --------------------------- Параметры ---------------------------
36
+ ds_path = "datasets/dsb_640_vae_qwen_temp"
37
  project = "transformer"
38
 
39
  gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
 
81
  torch.backends.cudnn.allow_tf32 = True
82
  torch.backends.cuda.enable_flash_sdp(True)
83
  torch.backends.cuda.enable_mem_efficient_sdp(True)
84
+ torch.backends.cuda.enable_math_sdp(True)
85
+ save_barrier = 1.4
86
  warmup_percent = 0.0025
87
  betta2 = 0.997
88
  eps = 1e-6
 
223
  for i, length in enumerate(lengths):
224
  hidden[i, length:] = 0
225
 
226
+ return hidden, toks.attention_mask.to(dtype=torch.bool)
227
 
228
  @torch.no_grad()
229
  def encode_texts_fast(text, max_length=max_length):
 
244
  for i, length in enumerate(lengths):
245
  last_hidden[i, length:] = 0
246
 
247
+ return last_hidden, toks.attention_mask.to(dtype=torch.bool)
248
 
249
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
250
  if shift_factor is None:
 
375
  masks = torch.tensor(
376
  np.array([item["attention_mask"] for item in samples_data]),
377
  device=device,
378
+ dtype=torch.bool
379
  )
380
  else:
381
  embeddings, masks = encode_texts(texts,max_length)
 
388
  if limit > 0:
389
  dataset = load_from_disk(ds_path).select(range(limit))
390
  else:
391
+ print(">>> Поиск чанков датасета...")
392
+ chunks = []
393
+ for d in os.listdir(ds_path):
394
+ full_p = os.path.join(ds_path, d)
395
+ if os.path.isdir(full_p):
396
+ chunks.append(full_p)
397
+
398
+ if not chunks:
399
+ print("❌ Чанки не найдены!")
400
+
401
+ print(f">>> Найдено чанков: {len(chunks)}. Начинаю загрузку и объединение...")
402
+
403
+ # 2. Ленивая загрузка всех чанков
404
+ # load_from_disk не ест RAM, пока мы не обращаемся к данным
405
+ ds_list = []
406
+ for c in chunks:
407
+ try:
408
+ ds_list.append(load_from_disk(c))
409
+ except Exception as e:
410
+ print(f"⚠️ Ошибка загрузки чанка {c}: {e}")
411
+
412
+ # 3. Конкатенация (создает виртуальный объединенный датасет)
413
+ dataset = concatenate_datasets(ds_list)
414
+ #dataset = load_from_disk(ds_path)
415
 
416
  print(f"images: {len(dataset)}")
417
 
 
447
  ]
448
 
449
  embeddings, attention_mask = encode_texts(texts,max_length)
450
+ attention_mask = attention_mask.to(dtype=torch.bool)
451
 
452
  return latents, embeddings, attention_mask
453
 
 
575
  hidden_dim = 2048
576
  seq_len = max_length
577
  empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
578
+ empty_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
579
  return empty_emb, empty_mask
580
 
581
  uncond_emb, uncond_mask = encode_texts([neg_prompt],max_length)
transformer/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3188c7ffba9a6fb9e646536503a99a8e0b1530251793ab1f5ff4b73b4df04542
3
- size 7825687184
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee77d7083d1f968607fbcc531deae347d72c2fb229bbe40356e44a8edae26aec
3
+ size 3912877104
wandb/debug-internal.log DELETED
The diff for this file is too large to render. See raw diff
 
wandb/debug-internal.log ADDED
@@ -0,0 +1 @@
 
 
1
+ run-20260513_080408-xhrf3max/logs/debug-internal.log
wandb/debug.log DELETED
@@ -1,19 +0,0 @@
1
- 2026-05-05 07:53:13,457 INFO MainThread:43955 [wandb_setup.py:_flush():81] Current SDK version is 0.26.1
2
- 2026-05-05 07:53:13,457 INFO MainThread:43955 [wandb_setup.py:_flush():81] Configure stats pid to 43955
3
- 2026-05-05 07:53:13,457 INFO MainThread:43955 [wandb_setup.py:_flush():81] Loading settings from environment variables
4
- 2026-05-05 07:53:13,457 INFO MainThread:43955 [wandb_init.py:setup_run_log_directory():723] Logging user logs to /workspace/2b/wandb/run-20260505_075313-ti70f47q/logs/debug.log
5
- 2026-05-05 07:53:13,458 INFO MainThread:43955 [wandb_init.py:setup_run_log_directory():724] Logging internal logs to /workspace/2b/wandb/run-20260505_075313-ti70f47q/logs/debug-internal.log
6
- 2026-05-05 07:53:13,458 INFO MainThread:43955 [wandb_init.py:init():850] calling init triggers
7
- 2026-05-05 07:53:13,458 INFO MainThread:43955 [wandb_init.py:init():855] wandb.init called with sweep_config: {}
8
- config: {'batch_size': 24, 'base_learning_rate': 1.3333333333333335e-05, 'num_epochs': 1, 'optimizer_type': 'adafactor', '_wandb': {}}
9
- 2026-05-05 07:53:13,458 INFO MainThread:43955 [wandb_init.py:init():898] starting backend
10
- 2026-05-05 07:53:13,663 INFO MainThread:43955 [wandb_init.py:init():913] sending inform_init request
11
- 2026-05-05 07:53:13,842 INFO MainThread:43955 [wandb_init.py:init():918] backend started and connected
12
- 2026-05-05 07:53:13,844 INFO MainThread:43955 [wandb_init.py:init():988] updated telemetry
13
- 2026-05-05 07:53:13,845 INFO MainThread:43955 [wandb_init.py:init():1011] communicating run to backend with 90.0 second timeout
14
- 2026-05-05 07:53:14,174 INFO MainThread:43955 [wandb_init.py:init():1056] starting run threads in backend
15
- 2026-05-05 07:53:14,261 INFO MainThread:43955 [wandb_run.py:_console_start():2554] atexit reg
16
- 2026-05-05 07:53:14,262 INFO MainThread:43955 [wandb_run.py:_redirect():2403] redirect: wrap_raw
17
- 2026-05-05 07:53:14,262 INFO MainThread:43955 [wandb_run.py:_redirect():2472] Wrapping output streams.
18
- 2026-05-05 07:53:14,262 INFO MainThread:43955 [wandb_run.py:_redirect():2495] Redirects installed.
19
- 2026-05-05 07:53:14,267 INFO MainThread:43955 [wandb_init.py:init():1094] run started, returning control to user process
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
wandb/debug.log ADDED
@@ -0,0 +1 @@
 
 
1
+ run-20260513_080408-xhrf3max/logs/debug.log