recoilme committed
Commit 19de470 · 1 Parent(s): 4265f4a
dataset.py CHANGED
@@ -16,22 +16,19 @@ from typing import Dict, List, Tuple, Optional, Any
 from PIL import Image
 from tqdm import tqdm
 from datetime import timedelta
-import subprocess
-import tempfile
 
 # ---------------- 1️⃣ Settings ----------------
 dtype = torch.float16
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 batch_size = 10
-min_size = 320
-max_size = 640
-step = 64
-empty_share = 0.05
+min_size = 384
+max_size = 768
+step = 32
+empty_share = 0.0  # share of captions blanked for classifier-free guidance
 limit = 0
 # Main processing pipeline
-folder_path = "/workspace/tar"
-save_path = "/workspace/640"
-dir_tmp = "/workspace/tmp"
+folder_path = "/workspace/dataset/dataset/ae3"
+save_path = "/workspace/ae3_768"
 os.makedirs(save_path, exist_ok=True)
 
 # Function to clear CUDA memory
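For context, these four knobs drive the aspect-ratio bucketing. The bucket-generation code is outside this hunk, but judging by the chunk_{i}_size_{W}x{H} paths produced later in the script, a sketch of what they plausibly control:

# Sketch only (the actual bucketing code is not part of this diff):
# candidate resolutions as multiples of `step` between min_size and max_size.
buckets = [
    (w, h)
    for w in range(min_size, max_size + 1, step)
    for h in range(min_size, max_size + 1, step)
]
# With min_size=384, max_size=768, step=32 this yields 13*13 = 169 buckets;
# the old settings (320/640/64) gave 6*6 = 36.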
@@ -45,18 +42,21 @@ def clear_cuda_memory():
 # ---------------- 2️⃣ Model loading ----------------
 def load_models():
     print("Loading models...")
-    vae = AutoencoderKL.from_pretrained("AiArtLab/sdxs3d", subfolder="vae", torch_dtype=dtype).to(device).eval()
-
-    model_name = "Qwen/Qwen3-0.6B"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=dtype,
-        device_map=device
-    ).eval()
-    return vae, model, tokenizer
-
-vae, model, tokenizer = load_models()
+    vae = AutoencoderKL.from_pretrained("AiArtLab/sdxs", subfolder="vae1x", torch_dtype=dtype).to(device).eval()
+
+    #model_name = "Qwen/Qwen3-0.6B"
+    #tokenizer = AutoTokenizer.from_pretrained(model_name)
+    #model = AutoModelForCausalLM.from_pretrained(
+    #    model_name,
+    #    torch_dtype=dtype,
+    #    device_map=device
+    #).eval()
+    #tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
+    #model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B').to("cuda")
+    return vae  #, model, tokenizer
+
+#vae, model, tokenizer = load_models()
+vae = load_models()
 
 shift_factor = getattr(vae.config, "shift_factor", 0.0)
 if shift_factor is None:
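The text encoder is now disabled; only the VAE is loaded, and the `shift_factor` read here feeds latent normalization. A minimal sketch of the standard diffusers convention, assuming a hypothetical `pixel_values` batch tensor (the encode step itself is outside this hunk):

with torch.inference_mode():
    # Encode to latent space; `pixel_values` is a stand-in [B, 3, H, W] tensor.
    z = vae.encode(pixel_values.to(device, dtype)).latent_dist.sample()
    # Standard diffusers normalization: shift, then scale.
    scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
    latents = (z - shift_factor) * scaling_factor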
@@ -136,7 +136,7 @@ def last_token_pool(last_hidden_states: torch.Tensor,
     batch_size = last_hidden_states.shape[0]
     return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
 
-def encode_texts_batch(texts, tokenizer, model, device="cuda", max_length=150):
+def encode_texts_batch(texts, tokenizer, model, device="cuda", max_length=150, normalize=False):
     with torch.inference_mode():
         # Tokenization
         batch = tokenizer(
@@ -147,16 +147,32 @@ def encode_texts_batch(texts, tokenizer, model, device="cuda", max_length=150):
             max_length=max_length
         ).to(device)
 
+        # Forward pass through the model
+        #outputs = model(**batch)
+
+        # Last-token pooling
+        #embeddings = last_token_pool(outputs.last_hidden_state, batch["attention_mask"])
+
+        # L2 normalization (optional; usually wanted for semantic search)
+        #if normalize:
+        #    embeddings = F.normalize(embeddings, p=2, dim=1)
+
         # Forward pass through the base model (inside the CausalLM)
         outputs = model.model(**batch, output_hidden_states=True)
 
         # Take the last layer (embeddings of all tokens)
-        hidden_states = outputs.last_hidden_state
+        hidden_states = outputs.hidden_states[-1]  # [B, L, D]
+
+        # Optionally L2-normalize each token embedding (as in CLIP)
+        if normalize:
+            hidden_states = F.normalize(hidden_states, p=2, dim=-1)
 
     return hidden_states.cpu().numpy()  # embeddings.unsqueeze(1).cpu().numpy()
 
 def clean_label(label):
-    label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "")
+    label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "").replace("The image depicts ", "").replace("The image presents ", "").replace("The image features ", "").replace("The image portrays ", "").replace("The image is ", "").strip()
+    if label.startswith("."):
+        label = label[1:].lstrip()
     return label
 
 def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
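An illustrative run of the extended clean_label (captions are hypothetical, not from the dataset):

print(clean_label("The image depicts a red fox resting on mossy rocks."))
# -> "a red fox resting on mossy rocks."
print(clean_label(". Portrait of an astronaut"))
# -> "Portrait of an astronaut"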
@@ -223,11 +239,11 @@ def encode_to_latents(images, texts):
     text_labels = [clean_label(text) for text in texts]
 
     model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)
-    embeddings = encode_texts_batch(model_prompts, tokenizer, model)
+    #embeddings = encode_texts_batch(model_prompts, tokenizer, model)
 
     return {
         "vae": latents_np,
-        "embeddings": embeddings,
+        #"embeddings": embeddings,
         "text": text_labels,
         "width": widths,
         "height": heights
@@ -341,10 +357,6 @@ def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000,
         # Save the group's results
         group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}"
         processed_group.save_to_disk(group_save_path)
-
-        subprocess.run(["tar", "-I", "zstd", "-cf", f"{group_save_path}.tar.zst", "-C", group_save_path, "."], check=True)
-        shutil.rmtree(group_save_path)
-
         clear_cuda_memory()
         elapsed = time.time() - start_time
         processed = (chunk_idx - 1) * chunk_size + sum([len(sg["image_paths"]) for sg in list(size_groups.values())[:list(size_groups.values()).index(group_data) + 1]])
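The `remaining` value used in the ETA print below is computed outside the shown hunks; a linear extrapolation consistent with the printed output would be:

# Assumed ETA arithmetic (the defining line is not part of this diff):
remaining = elapsed * (total_files - processed) / max(processed, 1)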
@@ -354,60 +366,22 @@ def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000,
         remaining_str = str(timedelta(seconds=int(remaining)))
         print(f"ETA: elapsed {elapsed_str}, remaining {remaining_str}, progress {processed}/{total_files} ({processed/total_files:.1%})")
 
-# ---------------- 7️⃣ Merging ----------------
-def safe_repack(ds):
-    """
-    Repack the dataset into small .arrow files.
-    WHAT IT DOES:
-    - reduces the Arrow chunk size
-    - writer_batch_size=1000 → Arrow files of 30–60 MB
-    """
-    return ds.map(lambda x: x, batched=True, writer_batch_size=1000)
-
-
+# ---------------- 7️⃣ Chunk merging ----------------
 def combine_chunks(temp_path, final_path):
-    archives = sorted(
-        f for f in os.listdir(temp_path)
-        if f.endswith(".tar.zst")
-    )
-    archives = [os.path.join(temp_path, f) for f in archives]
-
-    print(f"Found {len(archives)} archives.")
-
-    # Initialize an empty dataset
-    merged_ds = None
-
-    for i, arc in enumerate(archives):
-        print(f"[{i+1}/{len(archives)}] Processing {arc}")
-
-        # Unpack
-        tmp = tempfile.mkdtemp(dir=dir_tmp)
-        subprocess.run(["tar", "-I", "zstd", "-xf", arc, "-C", tmp], check=True)
-
-        # Load the dataset
-        ds = load_from_disk(tmp)
-
-        # Repack the chunk dataset to shrink its Arrow files
-        ds = safe_repack(ds)
-
-        # Merge
-        if merged_ds is None:
-            merged_ds = ds
-        else:
-            merged_ds = concatenate_datasets([merged_ds, ds])
-
-        # cleanup
-        shutil.rmtree(tmp)
-        os.remove(arc)
-
-    # Final repack
-    print("⚙️ Final repack...")
-    merged_ds = safe_repack(merged_ds)
-
-    print("💾 Final save...")
-    merged_ds.save_to_disk(final_path)
-
-    print(f"✅ Dataset saved: {final_path}")
+    """Merge the processed chunks into the final dataset."""
+    chunks = sorted([
+        os.path.join(temp_path, d)
+        for d in os.listdir(temp_path)
+        if d.startswith("chunk_")
+    ])
+
+    datasets = [load_from_disk(chunk) for chunk in chunks]
+    combined = concatenate_datasets(datasets)
+    combined.save_to_disk(final_path)
+
+    print(f"✅ Dataset successfully saved to: {final_path}")
 
 # Create a temporary folder for the chunks
 temp_path = f"{save_path}_temp"
@@ -418,7 +392,7 @@ image_paths, text_paths, width, height = process_folder(folder_path, limit)
 print(f"Found {len(image_paths)} images in total")
 
 # Process with chunking
-process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=batch_size)
+process_in_chunks(image_paths, text_paths, width, height, chunk_size=20000, batch_size=batch_size)
 
 # Merge chunks into the final dataset
 combine_chunks(temp_path, save_path)
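Design note: load_from_disk and concatenate_datasets memory-map the chunk Arrow files, which is why the tar/zstd round-trip and safe_repack could be dropped without blowing up RAM. If smaller output shards are still wanted, save_to_disk accepts max_shard_size (a possible variant, not what this commit does):

combined.save_to_disk(final_path, max_shard_size="500MB")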
 
samples/unet_320x640_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: e414aa0a90d55a49d737a74608c8f33b57d58f03b0794b1bf6ef1da749258edb
  • Pointer size: 130 Bytes
  • Size of remote file: 55.2 kB

Git LFS Details (after)

  • SHA256: cbd28220db18f95d0fee2e027af5bc8829b9d3ee788c1d94b521d5c37c368983
  • Pointer size: 130 Bytes
  • Size of remote file: 62.4 kB
samples/unet_384x640_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 4542cf64e389bedc4a59eb55162569e6baf78b333f7b4eb12ede83e8314627f0
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB

Git LFS Details (after)

  • SHA256: c63700ae77198f9ffe16431218661d6703daaf085b5ebc5b9f5591b582d16bfd
  • Pointer size: 131 Bytes
  • Size of remote file: 159 kB
samples/unet_448x640_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: b5336b5074ddf580d53550dca84bff8e8df33e9c41e14535ae254e1632d8b39c
  • Pointer size: 130 Bytes
  • Size of remote file: 63.1 kB

Git LFS Details (after)

  • SHA256: 465c5bf3e1b611dd046d0d28b1bc4850c2562fef837bf85816ae00959f28f78c
  • Pointer size: 130 Bytes
  • Size of remote file: 81.7 kB
samples/unet_512x640_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: db3f1f8a85554109e263c5315d54db1bd644691f8cfada79008ec82f26eb71b3
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB

Git LFS Details (after)

  • SHA256: e94978b4e71b7b45166215ce68d794a95dcd5ad4f3ea39f31be1189110fd958f
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
samples/unet_576x640_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 7cde37bc4f9e1e78c48bcc920d5b35e891a8c04166f5cde12c8956de1a0c98f1
  • Pointer size: 131 Bytes
  • Size of remote file: 222 kB

Git LFS Details (after)

  • SHA256: e3cd197d0f08c8da4e5e3c51258c4dae62ee498dfef2d876d9c0e93a7a2d8fc4
  • Pointer size: 131 Bytes
  • Size of remote file: 244 kB
samples/unet_640x320_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 2dfab00d9532e20fda124e6c03e17eb60bdd0f99411da359d0eac87a5d46ed64
  • Pointer size: 130 Bytes
  • Size of remote file: 73.5 kB

Git LFS Details (after)

  • SHA256: bbaebc70f65baa27de7ef43a6028c9bf11a73d7cf2874e402183edc590006ee8
  • Pointer size: 130 Bytes
  • Size of remote file: 88.3 kB
samples/unet_640x384_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: d8858222a7168440d61651cb636460ccfa6160261a3c1ee02cdcfa64945ed84e
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB

Git LFS Details (after)

  • SHA256: 31315d6c63ecadd9ba9a9d568ae00047c2236a8d9cf87e9e357e477db0e427fc
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
samples/unet_640x448_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 86ae9eca3fb4ede85edfeebea73dfa5ef37073d5a4ee98c1cd5f3b3fea6227fb
  • Pointer size: 130 Bytes
  • Size of remote file: 97.2 kB

Git LFS Details (after)

  • SHA256: b98d7ce8c61c9670407c5480614e695a8ccfee23c37fd08297edc6e8b4f59355
  • Pointer size: 130 Bytes
  • Size of remote file: 92.2 kB
samples/unet_640x512_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: cb61b32bfc4874bd5ce6e8000b42aa857ae3cbbcdaadc15cb975af3de551ec55
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB

Git LFS Details (after)

  • SHA256: 5f29af11e61ea1773a84ab8cc9fa880bb60c1fdef4584be9157e9f1f8adb3d19
  • Pointer size: 131 Bytes
  • Size of remote file: 122 kB
samples/unet_640x576_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: acfbd867ca33935affa45e0df87423585e4c491fe19aa18cd89d15853658b128
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB

Git LFS Details (after)

  • SHA256: e4263e6d85ffbbf5b0fd0cb3545ffb9fb9a081e73098e714e853f196a000845e
  • Pointer size: 131 Bytes
  • Size of remote file: 182 kB
samples/unet_640x640_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: c10691ff09a6e6f55f71a46acba37b0ac1dadc66903b715e9b18b42ccc158ba4
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB

Git LFS Details (after)

  • SHA256: c839b3deba43cf54c33c903c7df431797f17c5734cf766ae3c0c32d84a3a6847
  • Pointer size: 131 Bytes
  • Size of remote file: 202 kB
src/untar.sh ADDED
@@ -0,0 +1,2 @@
+#!/bin/bash
+find . -maxdepth 1 -type f \( -name "*.tar*" -o -name "*.tgz" -o -name "*.tar.bz2" \) -exec sh -c 'tar -xf "$1" && rm "$1"' _ {} \;
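Usage note: run it from the directory holding the archives (e.g. cd /workspace && bash src/untar.sh). The -maxdepth 1 keeps find from descending into freshly extracted folders, and passing the filename as "$1" to sh -c (rather than splicing {} into the command string) keeps filenames with quotes or spaces safe.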
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8820baa1b4aab525952d765f0162a0fcaf2a93641ffd8146683880c17b31c71e
+oid sha256:08b55f95640f0615bc957b7e0641973220578146f32d1647f900fa74c93f1f4d
 size 6205958296