recoilme commited on
Commit
b4d8bb3
·
1 Parent(s): 8e4732e
.gitignore CHANGED
@@ -8,7 +8,7 @@ src/samples
8
  # cache
9
  cache
10
  datasets/mjnj
11
- datasets/640
12
  test
13
  wandb
14
  nohup.out
 
8
  # cache
9
  cache
10
  datasets/mjnj
11
+ datasets/*
12
  test
13
  wandb
14
  nohup.out
datasets/alchemist/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f1016c8d2672f322bf76bb2ddc5b3ddf444a167a19527228efd4cd8f975d64a5
3
- size 186532912
 
 
 
 
datasets/alchemist/dataset_info.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1033764907721e276d6b5032efdfb9479464c73dce5346e27eedc6e28ef40d6e
3
- size 818
 
 
 
 
datasets/alchemist/state.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:63b1be8351de9c7703925d48fad88119bf2f7cb1bed54721c94037573d6032b4
3
- size 333
 
 
 
 
datasets/butterfly/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d8479e8b4cf0c3505189c608cedf8b35ab073f14c6b7db0a9e66b75925e1c519
3
- size 53255512
 
 
 
 
datasets/butterfly/state.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f5c2862bf6140b141f42cfafae0901a621cc1165286e9975d734f9ded8a8862f
3
- size 333
 
 
 
 
src/dataset_from_folder_qwen.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pip install flash-attn --no-build-isolation
2
+ import torch
3
+ import os
4
+ import gc
5
+ import numpy as np
6
+ import random
7
+ import json
8
+ import shutil
9
+ import time
10
+
11
+ from datasets import Dataset, load_from_disk, concatenate_datasets
12
+ from diffusers import AutoencoderKL,AutoencoderKLWan
13
+ from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
14
+ from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
15
+ from typing import Dict, List, Tuple, Optional, Any
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+ from datetime import timedelta
19
+
20
+ # ---------------- 1️⃣ Настройки ----------------
21
+ dtype = torch.float16
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ batch_size = 5
24
+ min_size = 320 #192 #256 #192
25
+ max_size = 640 #384 #256 #384
26
+ step = 64 #64
27
+ empty_share = 0.05
28
+ limit = 0
29
+ # Основная процедура обработки
30
+ folder_path = "/workspace/sdxs3d/datasets/eshooshoo_all" #alchemist"
31
+ save_path = "/workspace/sdxs3d/datasets/esh640" #"alchemist"
32
+ os.makedirs(save_path, exist_ok=True)
33
+
34
+ # Функция для очистки CUDA памяти
35
+ def clear_cuda_memory():
36
+ if torch.cuda.is_available():
37
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
38
+ print(f"used_gb: {used_gb:.2f} GB")
39
+ torch.cuda.empty_cache()
40
+ gc.collect()
41
+
42
+ # ---------------- 2️⃣ Загрузка моделей ----------------
43
+ def load_models():
44
+ print("Загрузка моделей...")
45
+ vae = AutoencoderKL.from_pretrained("AiArtLab/sdxs3d",subfolder="vae",torch_dtype=dtype).to(device).eval()
46
+
47
+ model_name = "Qwen/Qwen3-0.6B"
48
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
49
+ model = AutoModelForCausalLM.from_pretrained(
50
+ model_name,
51
+ torch_dtype=dtype,
52
+ device_map=device
53
+ ).eval()
54
+ #tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
55
+ #model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B').to("cuda")
56
+ return vae, model, tokenizer
57
+
58
+ vae, model, tokenizer = load_models()
59
+
60
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
61
+ if shift_factor is None:
62
+ shift_factor = 0.0
63
+
64
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
65
+ if scaling_factor is None:
66
+ scaling_factor = 1.0
67
+
68
+ latents_mean = getattr(vae.config, "latents_mean", None)
69
+ latents_std = getattr(vae.config, "latents_std", None)
70
+
71
+ # ---------------- 3️⃣ Трансформации ----------------
72
+ def get_image_transform(min_size=256, max_size=512, step=64):
73
+ def transform(img, dry_run=False):
74
+ # Сохраняем исходные размеры изображения
75
+ original_width, original_height = img.size
76
+
77
+ # 0. Ресайз: масштабируем изображение, чтобы максимальная сторона была равна max_size
78
+ if original_width >= original_height:
79
+ new_width = max_size
80
+ new_height = int(max_size * original_height / original_width)
81
+ else:
82
+ new_height = max_size
83
+ new_width = int(max_size * original_width / original_height)
84
+
85
+ if new_height < min_size or new_width < min_size:
86
+ # 1. Ресайз: масштабируем изображение, чтобы минимальная сторона была равна min_size
87
+ if original_width <= original_height:
88
+ new_width = min_size
89
+ new_height = int(min_size * original_height / original_width)
90
+ else:
91
+ new_height = min_size
92
+ new_width = int(min_size * original_width / original_height)
93
+
94
+ # 2. Проверка: если одна из сторон превышает max_size, готовимся к обрезке
95
+ crop_width = min(max_size, (new_width // step) * step)
96
+ crop_height = min(max_size, (new_height // step) * step)
97
+
98
+ # Убеждаемся, что размеры обрезки не меньше min_size
99
+ crop_width = max(min_size, crop_width)
100
+ crop_height = max(min_size, crop_height)
101
+
102
+ # Если запрошен только предварительный расчёт размеров
103
+ if dry_run:
104
+ return crop_width, crop_height
105
+
106
+ # Конвертация в RGB и ресайз
107
+ img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
108
+
109
+ # Определение координат обрезки (обрезаем с учетом вотермарок - треть сверху)
110
+ top = (new_height - crop_height) // 3
111
+ left = 0
112
+
113
+ # Обрезка изображения
114
+ img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))
115
+
116
+ # Сохраняем итоговые размеры после всех преобразований
117
+ final_width, final_height = img_cropped.size
118
+
119
+ # тензор
120
+ img_tensor = ToTensor()(img_cropped)
121
+ img_tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(img_tensor)
122
+ return img_tensor, img_cropped, final_width, final_height
123
+
124
+ return transform
125
+
126
+ # ---------------- 4️⃣ Функции обработки ----------------
127
+ def last_token_pool(last_hidden_states: torch.Tensor,
128
+ attention_mask: torch.Tensor) -> torch.Tensor:
129
+ # Определяем, есть ли left padding
130
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
131
+ if left_padding:
132
+ return last_hidden_states[:, -1]
133
+ else:
134
+ sequence_lengths = attention_mask.sum(dim=1) - 1
135
+ batch_size = last_hidden_states.shape[0]
136
+ return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
137
+
138
+ def encode_texts_batch(texts, tokenizer, model, device="cuda", max_length=150, normalize=False):
139
+ with torch.inference_mode():
140
+ # Токенизация
141
+ batch = tokenizer(
142
+ texts,
143
+ return_tensors="pt",
144
+ padding="max_length",
145
+ truncation=True,
146
+ max_length=max_length
147
+ ).to(device)
148
+
149
+ # Прогон через модель
150
+ #outputs = model(**batch)
151
+
152
+ # Пулинг по last token
153
+ #embeddings = last_token_pool(outputs.last_hidden_state, batch["attention_mask"])
154
+
155
+ # L2-нормализация (опционально, обычно нужна для семантического поиска)
156
+ #if normalize:
157
+ # embeddings = F.normalize(embeddings, p=2, dim=1)
158
+
159
+ # Прогон через базовую модель (внутри CausalLM)
160
+ outputs = model.model(**batch, output_hidden_states=True)
161
+
162
+ # Берем последний слой (эмбеддинги всех токенов)
163
+ hidden_states = outputs.hidden_states[-1] # [B, L, D]
164
+
165
+ # Можно применить нормализацию по каждому токену (как в CLIP)
166
+ if normalize:
167
+ hidden_states = F.normalize(hidden_states, p=2, dim=-1)
168
+
169
+ return hidden_states.cpu().numpy() # embeddings.unsqueeze(1).cpu().numpy()
170
+
171
+ def clean_label(label):
172
+ label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "")
173
+ return label
174
+
175
+ def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
176
+ """
177
+ Обрабатывает список меток для classifier-free guidance.
178
+
179
+ С вероятностью prob_to_make_empty:
180
+ - Метка в первом списке заменяется на пустую строку.
181
+ - К метке во втором списке добавляется префикс "zero:".
182
+
183
+ В противном случае метки в обоих списках остаются оригинальными.
184
+
185
+ """
186
+ labels_for_model = []
187
+ labels_for_logging = []
188
+
189
+ for label in original_labels:
190
+ if random.random() < prob_to_make_empty:
191
+ labels_for_model.append("") # Заменяем на пустую строку для модели
192
+ labels_for_logging.append(f"zero: {label}") # Добавляем префикс для логгирования
193
+ else:
194
+ labels_for_model.append(label) # Оставляем оригинальную метку для модели
195
+ labels_for_logging.append(label) # Оставляем оригинальную метку для логгирования
196
+
197
+ return labels_for_model, labels_for_logging
198
+
199
+ def encode_to_latents(images, texts):
200
+ transform = get_image_transform(min_size, max_size, step)
201
+
202
+ try:
203
+ # Обработка изображений (все одинакового размера)
204
+ transformed_tensors = []
205
+ pil_images = []
206
+ widths, heights = [], []
207
+
208
+ # Применяем трансформацию ко всем изображениям
209
+ for img in images:
210
+ try:
211
+ t_img, pil_img, w, h = transform(img)
212
+ transformed_tensors.append(t_img)
213
+ pil_images.append(pil_img)
214
+ widths.append(w)
215
+ heights.append(h)
216
+ except Exception as e:
217
+ print(f"Ошибка трансформации: {e}")
218
+ continue
219
+
220
+ if not transformed_tensors:
221
+ return None
222
+
223
+ # Создаём батч
224
+ batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
225
+ if batch_tensor.ndim==5:
226
+ batch_tensor = batch_tensor.unsqueeze(2) # [B, C, 1, H, W]
227
+
228
+ # Кодируем батч
229
+ with torch.no_grad():
230
+ posteriors = vae.encode(batch_tensor).latent_dist.mode()
231
+ latents = (posteriors - shift_factor) / scaling_factor
232
+
233
+ latents_np = latents.to(dtype).cpu().numpy()
234
+
235
+ # Обрабатываем тексты
236
+ text_labels = [clean_label(text) for text in texts]
237
+
238
+ model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)
239
+ embeddings = encode_texts_batch(model_prompts, tokenizer, model)
240
+
241
+ return {
242
+ "vae": latents_np,
243
+ "embeddings": embeddings,
244
+ "text": text_labels,
245
+ "width": widths,
246
+ "height": heights
247
+ }
248
+
249
+ except Exception as e:
250
+ print(f"Критическая ошибка в encode_to_latents: {e}")
251
+ raise
252
+
253
+
254
+ # ---------------- 5️⃣ Обработка папки с изображениями и текстами ----------------
255
+ def process_folder(folder_path, limit=None):
256
+ """
257
+ Рекурсивно обходит указанную директорию и все вложенные директории,
258
+ собирая пути к изображениям и соответствующим текстовым файлам.
259
+ """
260
+ image_paths = []
261
+ text_paths = []
262
+ width = []
263
+ height = []
264
+ transform = get_image_transform(min_size, max_size, step)
265
+
266
+ # Используем os.walk для рекурсивного обхода директорий
267
+ for root, dirs, files in os.walk(folder_path):
268
+ for filename in files:
269
+ # Проверяем, является ли файл изображением
270
+ if filename.lower().endswith((".jpg", ".jpeg", ".png")):
271
+ image_path = os.path.join(root, filename)
272
+ try:
273
+ img = Image.open(image_path)
274
+ except Exception as e:
275
+ print(f"Ошибка при открытии {image_path}: {e}")
276
+ os.remove(image_path)
277
+ text_path = os.path.splitext(image_path)[0] + ".txt"
278
+ if os.path.exists(text_path):
279
+ os.remove(text_path)
280
+ continue
281
+ # Применяем трансформацию только для получения размеров
282
+ w, h = transform(img, dry_run=True)
283
+ # Формируем путь к текстовому файлу
284
+ text_path = os.path.splitext(image_path)[0] + ".txt"
285
+
286
+ # Добавляем пути, если текстовый файл существует
287
+ if os.path.exists(text_path) and min(w, h)>0:
288
+ image_paths.append(image_path)
289
+ text_paths.append(text_path)
290
+ width.append(w) # Добавляем в список
291
+ height.append(h) # Добавляем в список
292
+
293
+ # Проверяем ограничение на количество
294
+ if limit and limit>0 and len(image_paths) >= limit:
295
+ print(f"Достигнут лимит в {limit} изображений")
296
+ return image_paths, text_paths, width, height
297
+
298
+ print(f"Найдено {len(image_paths)} изображений с текстовыми описаниями")
299
+ return image_paths, text_paths, width, height
300
+
301
+ def process_in_chunks(image_paths, text_paths, width, height, chunk_size=50000, batch_size=1):
302
+ total_files = len(image_paths)
303
+ start_time = time.time()
304
+ chunks = range(0, total_files, chunk_size)
305
+
306
+ for chunk_idx, start in enumerate(chunks, 1):
307
+ end = min(start + chunk_size, total_files)
308
+ chunk_image_paths = image_paths[start:end]
309
+ chunk_text_paths = text_paths[start:end]
310
+ chunk_widths = width[start:end] if isinstance(width, list) else [width] * len(chunk_image_paths)
311
+ chunk_heights = height[start:end] if isinstance(height, list) else [height] * len(chunk_image_paths)
312
+
313
+ # Чтение текстов
314
+ chunk_texts = []
315
+ for text_path in chunk_text_paths:
316
+ try:
317
+ with open(text_path, 'r', encoding='utf-8') as f:
318
+ text = f.read().strip()
319
+ chunk_texts.append(text)
320
+ except Exception as e:
321
+ print(f"Ошибка чтения {text_path}: {e}")
322
+ chunk_texts.append("")
323
+
324
+ # Группируем изображения по размерам
325
+ size_groups = {}
326
+ for i in range(len(chunk_image_paths)):
327
+ size_key = (chunk_widths[i], chunk_heights[i])
328
+ if size_key not in size_groups:
329
+ size_groups[size_key] = {"image_paths": [], "texts": []}
330
+ size_groups[size_key]["image_paths"].append(chunk_image_paths[i])
331
+ size_groups[size_key]["texts"].append(chunk_texts[i])
332
+
333
+ # Обрабатываем каждую группу размеров отдельно
334
+ for size_key, group_data in size_groups.items():
335
+ print(f"Обработка группы с размером {size_key[0]}x{size_key[1]} - {len(group_data['image_paths'])} изображений")
336
+
337
+ group_dataset = Dataset.from_dict({
338
+ "image_path": group_data["image_paths"],
339
+ "text": group_data["texts"]
340
+ })
341
+
342
+ # Теперь можно использовать указанный batch_size, т.к. все изображения одного размера
343
+ processed_group = group_dataset.map(
344
+ lambda examples: encode_to_latents(
345
+ [Image.open(path) for path in examples["image_path"]],
346
+ examples["text"]
347
+ ),
348
+ batched=True,
349
+ batch_size=batch_size,
350
+ #remove_columns=["image_path"],
351
+ desc=f"Обработка группы размера {size_key[0]}x{size_key[1]}"
352
+ )
353
+
354
+ # Сохраняем результаты группы
355
+ group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}"
356
+ processed_group.save_to_disk(group_save_path)
357
+ clear_cuda_memory()
358
+ elapsed = time.time() - start_time
359
+ processed = (chunk_idx - 1) * chunk_size + sum([len(sg["image_paths"]) for sg in list(size_groups.values())[:list(size_groups.values()).index(group_data) + 1]])
360
+ if processed > 0:
361
+ remaining = (elapsed / processed) * (total_files - processed)
362
+ elapsed_str = str(timedelta(seconds=int(elapsed)))
363
+ remaining_str = str(timedelta(seconds=int(remaining)))
364
+ print(f"ETA: Прошло {elapsed_str}, Осталось {remaining_str}, Прогресс {processed}/{total_files} ({processed/total_files:.1%})")
365
+
366
+ # ---------------- 7️⃣ Объединение чанков ----------------
367
+ def combine_chunks(temp_path, final_path):
368
+ """Объединение обработанных чанков в финальный датасет"""
369
+ chunks = sorted([
370
+ os.path.join(temp_path, d)
371
+ for d in os.listdir(temp_path)
372
+ if d.startswith("chunk_")
373
+ ])
374
+
375
+ datasets = [load_from_disk(chunk) for chunk in chunks]
376
+ combined = concatenate_datasets(datasets)
377
+ combined.save_to_disk(final_path)
378
+
379
+ print(f"✅ Датасет успешно сохранен в: {final_path}")
380
+
381
+
382
+
383
+ # Создаем временную папку для чанков
384
+ temp_path = f"{save_path}_temp"
385
+ os.makedirs(temp_path, exist_ok=True)
386
+
387
+ # Получаем список файлов
388
+ image_paths, text_paths, width, height = process_folder(folder_path,limit)
389
+ print(f"Всего найдено {len(image_paths)} изображений")
390
+
391
+ # Обработка с чанкованием
392
+ process_in_chunks(image_paths, text_paths, width, height, chunk_size=100000, batch_size=batch_size)
393
+
394
+ # Объединение чанков в финальный датасет
395
+ combine_chunks(temp_path, save_path)
396
+
397
+ # Удаление временной папки
398
+ try:
399
+ shutil.rmtree(temp_path)
400
+ print(f"✅ Временная папка {temp_path} успешно удалена")
401
+ except Exception as e:
402
+ print(f"⚠️ Ошибка при удалении временной папки: {e}")
datasets/butterfly/dataset_info.json → src/result_grid2.png RENAMED
File without changes
src/sample-Copy1.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/sample.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
train.py CHANGED
@@ -26,14 +26,14 @@ import torch.nn.functional as F
26
  from collections import deque
27
 
28
  # --------------------------- Параметры ---------------------------
29
- ds_path = "/workspace/sdxs3d/datasets/640"
30
  project = "unet"
31
  batch_size = 64
32
  base_learning_rate = 6e-5
33
  min_learning_rate = 1e-5
34
  num_epochs = 80
35
  # samples/save per epoch
36
- sample_interval_share = 2
37
  use_wandb = True
38
  use_comet_ml = False
39
  save_model = True
 
26
  from collections import deque
27
 
28
  # --------------------------- Параметры ---------------------------
29
+ ds_path = "/workspace/sdxs3d/datasets/esh640"
30
  project = "unet"
31
  batch_size = 64
32
  base_learning_rate = 6e-5
33
  min_learning_rate = 1e-5
34
  num_epochs = 80
35
  # samples/save per epoch
36
+ sample_interval_share = 20
37
  use_wandb = True
38
  use_comet_ml = False
39
  save_model = True
unet/config.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0ef8fbaff98c8d479d68b566d07ef4fb8e51ac26b9e8b5a3cb2b23f9a978f6ca
3
- size 1874
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:afc06beff07034f0ce9f671c83222e7f78eedc3b3ce93293143accdebef1b111
3
+ size 1887
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5906a60153dc8273125a16770fab6015d15f01935fd7b44762b76482d313e346
3
- size 6184944280
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d933d318f2d42b37c31065a09c14ee0c03ec05a10d672667743a089d396086b
3
+ size 3092571208