| |
| import torch |
| import os |
| import gc |
| import numpy as np |
| import random |
| import json |
| import shutil |
| import time |
|
|
| from datasets import Dataset, load_from_disk, concatenate_datasets |
| from diffusers import AutoencoderKL,AutoencoderKLWan,AsymmetricAutoencoderKL,AutoencoderKLFlux2 |
| from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda |
| from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM |
| from typing import Dict, List, Tuple, Optional, Any |
| from PIL import Image |
| from tqdm import tqdm |
| from datetime import timedelta |
|
|
| |
# --- Run configuration --------------------------------------------------------
# Precision used for the VAE weights, the input batches and the saved latents.
dtype = torch.float16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Images per encode_to_latents() call.
batch_size = 5

# Size bucketing: each target side is rounded down to a multiple of `step`,
# capped at `max_size` and raised to at least `min_size`.
min_size, max_size, step = 640, 1280, 64

# Probability of dropping a caption (passed to process_labels_for_guidance).
empty_share = 0.0

# Maximum number of image/caption pairs to collect; 0 means "no limit".
limit = 0

folder_path = "/workspace/ds"            # source images + .txt captions
save_path = "/workspace/ds1234_flux32"   # final dataset location
os.makedirs(save_path, exist_ok=True)
|
|
| |
def clear_cuda_memory():
    """Run the garbage collector and release PyTorch's cached CUDA memory.

    When CUDA is available, the peak allocated memory (in GiB) is printed
    before the cache is emptied. On CPU-only hosts only the garbage
    collector runs.
    """
    # Collect first: tensors kept alive only by reference cycles give their
    # blocks back to PyTorch's caching allocator, and empty_cache() can then
    # return those blocks to the driver. The reverse order leaves them cached.
    gc.collect()
    if torch.cuda.is_available():
        # max_memory_allocated() is the PEAK allocation, not current usage.
        used_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"used_gb: {used_gb:.2f} GB")
        torch.cuda.empty_cache()
|
|
| |
def load_models():
    """Load the VAE from the local ``vae`` directory for inference.

    Returns:
        An ``AutoencoderKL`` loaded with the module-level ``dtype``, moved to
        ``device`` and switched to eval mode.
    """
    print("Загрузка моделей...")  # "Loading models..."
    model = AutoencoderKL.from_pretrained("vae", torch_dtype=dtype)
    model = model.to(device)
    model.eval()
    return model
|
|
vae = load_models()

# Latent normalisation constants read from the VAE config. A missing key or an
# explicit None falls back to the identity transform (shift 0, scale 1).
shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None:
    shift_factor = 0.0

scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None:
    scaling_factor = 1.0

# Optional per-channel latent statistics. Both names are always defined so that
# encode_to_latents() can test them with `is not None` even when the config
# does not provide them (otherwise that check raises NameError).
latents_mean = None
latents_std = None
mean = getattr(vae.config, "latents_mean", None)
std = getattr(vae.config, "latents_std", None)
if mean is not None and std is not None:
    # Shaped (1, C, 1, 1) to broadcast over a (B, C, H, W) latent batch
    # (assumes 4-D latents, as produced by AutoencoderKL).
    latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
    latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)
| |
| |
def get_image_transform(min_size=256, max_size=512, step=64):
    """Return a resize-and-crop transform that maps images onto size buckets.

    Sizing rule (applied to ``img.size``):
      1. Scale so the longer side equals ``max_size``.
      2. If that leaves either side below ``min_size``, instead scale so the
         shorter side equals ``min_size`` (the longer side may then exceed
         ``max_size``).
      3. Each output side is rounded down to a multiple of ``step``, capped at
         ``max_size`` and then raised to at least ``min_size``.

    The returned callable ``transform(img, dry_run=False)``:
      * with ``dry_run=True`` only reads ``img.size`` and returns the bucket
        ``(width, height)``;
      * otherwise converts to RGB, LANCZOS-resizes to the step-2 size, crops
        the bucket size (left edge at x=0, vertical offset one third of the
        excess height) and returns
        ``(tensor normalised to [-1, 1], cropped PIL image, width, height)``.

    NOTE(review): the crop is anchored at the left edge, so wide images lose
    only their right-hand side; confirm a centred crop was not intended.
    """

    def _bucket(side):
        # Round down to the step grid, clamp to max_size, then floor at min_size.
        return max(min_size, min(max_size, side // step * step))

    def transform(img, dry_run=False):
        src_w, src_h = img.size

        # Pass 1: longer side -> max_size.
        if src_w >= src_h:
            scaled_w, scaled_h = max_size, int(max_size * src_h / src_w)
        else:
            scaled_w, scaled_h = int(max_size * src_w / src_h), max_size

        # Pass 2: if a side fell below min_size, shorter side -> min_size.
        if min(scaled_w, scaled_h) < min_size:
            if src_w <= src_h:
                scaled_w, scaled_h = min_size, int(min_size * src_h / src_w)
            else:
                scaled_w, scaled_h = int(min_size * src_w / src_h), min_size

        out_w = _bucket(scaled_w)
        out_h = _bucket(scaled_h)
        if dry_run:
            return out_w, out_h

        resized = img.convert("RGB").resize((scaled_w, scaled_h), Image.LANCZOS)
        y0 = (scaled_h - out_h) // 3
        cropped = resized.crop((0, y0, out_w, y0 + out_h))
        width, height = cropped.size

        tensor = ToTensor()(cropped)
        tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(tensor)
        return tensor, cropped, width, height

    return transform
|
|
| |
def clean_label(label):
    """Remove boilerplate phrases from a caption.

    Every occurrence (not only a leading one) of each phrase below is deleted,
    in the listed order. The result is whitespace-stripped, and a single
    leading "." (plus any whitespace after it) is dropped.
    """
    boilerplate = (
        "Image 1",
        "Image 2",
        "Image 3",
        "Image 4",
        "The image depicts ",
        "The image presents ",
        "The image features ",
        "The image portrays ",
        "The image is ",
    )
    for phrase in boilerplate:
        label = label.replace(phrase, "")
    label = label.strip()
    return label[1:].lstrip() if label.startswith(".") else label
|
|
def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
    """Randomly blank captions for classifier-free guidance.

    For each label, one ``random.random()`` draw decides whether it is dropped
    (draw < ``prob_to_make_empty``):
      * dropped: the model list gets ``""`` and the logging list gets the
        label prefixed with ``"zero: "``;
      * kept: both lists get the original label.

    Returns:
        ``(labels_for_model, labels_for_logging)``, both parallel to
        ``original_labels``.
    """
    labels_for_model, labels_for_logging = [], []
    for text in original_labels:
        drop = random.random() < prob_to_make_empty
        labels_for_model.append("" if drop else text)
        labels_for_logging.append(f"zero: {text}" if drop else text)
    return labels_for_model, labels_for_logging
|
|
|
|
def encode_to_latents(images, texts):
    """Encode a batch of PIL images into normalised VAE latents plus captions.

    Used as the batched function of ``Dataset.map`` in process_in_chunks.
    Images whose transform fails are skipped together with their caption, so
    every returned column stays row-aligned.

    Args:
        images: PIL images; ``images[i]`` is described by ``texts[i]``.
        texts: Raw caption strings, parallel to ``images``.

    Returns:
        ``None`` if no image could be transformed. Otherwise a dict with
        ``"vae"`` (numpy array of latents cast to the module-level ``dtype``),
        and ``"text"``, ``"width"``, ``"height"`` lists aligned with its rows.
        NOTE(review): if an image is skipped, the batch has fewer rows than
        the input; ``Dataset.map`` only accepts that when the input columns
        are removed.

    Raises:
        Re-raises any exception from stacking/encoding after printing it.
    """
    transform = get_image_transform(min_size, max_size, step)

    try:
        transformed_tensors = []
        kept_texts = []
        widths, heights = [], []

        # Walk images and captions together so a failed transform drops both.
        for img, text in zip(images, texts):
            try:
                t_img, _pil_img, w, h = transform(img)
            except Exception as e:
                print(f"Ошибка трансформации: {e}")
                continue
            transformed_tensors.append(t_img)
            kept_texts.append(text)
            widths.append(w)
            heights.append(h)

        if not transformed_tensors:
            return None

        # Stacking the (C, H, W) ToTensor outputs gives a 4-D (B, C, H, W) batch.
        # NOTE(review): a video VAE (e.g. AutoencoderKLWan) presumably needs an
        # extra frame axis (unsqueeze(2)) here; this path does not support it.
        batch_tensor = torch.stack(transformed_tensors).to(device, dtype)

        with torch.no_grad():
            posteriors = vae.encode(batch_tensor).latent_dist.mode()
            if latents_mean is not None and latents_std is not None:
                posteriors = (posteriors - latents_mean) / latents_std
            # NOTE(review): diffusers pipelines encode with
            # `(z - shift_factor) * scaling_factor`; this DIVIDES by
            # scaling_factor. Confirm the consumer of these latents inverts it.
            posteriors = (posteriors - shift_factor) / scaling_factor

        latents_np = posteriors.to(dtype).cpu().numpy()

        text_labels = [clean_label(text) for text in kept_texts]
        # NOTE(review): the stored "text" is the logging variant ("zero: "
        # prefix when dropped), not the blank model prompt. Confirm intended.
        _model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)

        return {
            "vae": latents_np,
            "text": text_labels,
            "width": widths,
            "height": heights,
        }

    except Exception as e:
        print(f"Критическая ошибка в encode_to_latents: {e}")
        raise
| |
|
|
| |
def process_folder(folder_path, limit=None):
    """Recursively collect image/caption pairs under ``folder_path``.

    Walks the tree with ``os.walk``. For every ``.jpg``/``.jpeg``/``.png`` file
    (case-insensitive), the caption is the file with the same stem and a
    ``.txt`` extension. Images that PIL cannot open are DELETED from disk,
    together with their caption file if present. For each kept image, the
    target bucket size is computed with a dry run of the image transform,
    which only reads ``img.size``.

    Args:
        folder_path: Root directory to scan.
        limit: Stop once this many pairs are collected; ``None`` or ``<= 0``
            means no limit.

    Returns:
        ``(image_paths, text_paths, width, height)``: four parallel lists.
    """
    image_paths = []
    text_paths = []
    width = []
    height = []
    transform = get_image_transform(min_size, max_size, step)

    for root, _dirs, files in os.walk(folder_path):
        for filename in files:
            if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
                continue

            image_path = os.path.join(root, filename)
            text_path = os.path.splitext(image_path)[0] + ".txt"
            try:
                img = Image.open(image_path)
            except Exception as e:
                # Unreadable image: remove it and its caption from the dataset.
                print(f"Ошибка при открытии {image_path}: {e}")
                os.remove(image_path)
                if os.path.exists(text_path):
                    os.remove(text_path)
                continue

            # Close the file handle deterministically: only the size is needed
            # here, so there is no reason to rely on garbage collection.
            with img:
                w, h = transform(img, dry_run=True)

            if os.path.exists(text_path) and min(w, h) > 0:
                image_paths.append(image_path)
                text_paths.append(text_path)
                width.append(w)
                height.append(h)

            if limit and limit > 0 and len(image_paths) >= limit:
                print(f"Достигнут лимит в {limit} изображений")
                return image_paths, text_paths, width, height

    print(f"Найдено {len(image_paths)} изображений с текстовыми описаниями")
    return image_paths, text_paths, width, height
| |
def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=1):
    """Encode images chunk by chunk, bucketed by target size, saving each bucket.

    The file list is split into chunks of ``chunk_size``. Within a chunk the
    captions are read and the images are grouped by target ``(width, height)``,
    so every VAE batch has a uniform shape. Each group is encoded with
    encode_to_latents via ``Dataset.map`` and saved to
    ``{save_path}_temp/chunk_<idx>_size_<w>x<h>``. Elapsed time and an ETA are
    printed after every group.

    Args:
        image_paths: Image file paths.
        text_paths: Caption file paths, parallel to ``image_paths``.
        width, height: Per-image target sizes as lists parallel to
            ``image_paths``, or a single value applied to every image.
        chunk_size: Number of files per chunk.
        batch_size: Images per encode_to_latents call.
    """
    total_files = len(image_paths)
    start_time = time.time()

    for chunk_idx, start in enumerate(range(0, total_files, chunk_size), 1):
        end = min(start + chunk_size, total_files)
        chunk_image_paths = image_paths[start:end]
        chunk_text_paths = text_paths[start:end]
        chunk_len = len(chunk_image_paths)
        chunk_widths = width[start:end] if isinstance(width, list) else [width] * chunk_len
        chunk_heights = height[start:end] if isinstance(height, list) else [height] * chunk_len

        # Read captions; an unreadable caption becomes "" instead of aborting.
        chunk_texts = []
        for text_path in chunk_text_paths:
            try:
                with open(text_path, 'r', encoding='utf-8') as f:
                    text = f.read().strip()
                chunk_texts.append(text)
            except Exception as e:
                print(f"Ошибка чтения {text_path}: {e}")
                chunk_texts.append("")

        # Bucket by target size so every batch stacks into a single tensor.
        size_groups = {}
        for path, text, w, h in zip(chunk_image_paths, chunk_texts, chunk_widths, chunk_heights):
            group = size_groups.setdefault((w, h), {"image_paths": [], "texts": []})
            group["image_paths"].append(path)
            group["texts"].append(text)

        done_in_chunk = 0
        for size_key, group_data in size_groups.items():
            print(f"Обработка группы с размером {size_key[0]}x{size_key[1]} - {len(group_data['image_paths'])} изображений")

            group_dataset = Dataset.from_dict({
                "image_path": group_data["image_paths"],
                "text": group_data["texts"],
            })

            processed_group = group_dataset.map(
                lambda examples: encode_to_latents(
                    [Image.open(path) for path in examples["image_path"]],
                    examples["text"],
                ),
                batched=True,
                batch_size=batch_size,
                desc=f"Обработка группы размера {size_key[0]}x{size_key[1]}",
            )

            group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}"
            processed_group.save_to_disk(group_save_path)
            clear_cuda_memory()

            # Running count of files handled so far. `start` equals
            # (chunk_idx - 1) * chunk_size, so no re-scan of the groups is needed.
            done_in_chunk += len(group_data["image_paths"])
            processed = start + done_in_chunk
            elapsed = time.time() - start_time
            if processed > 0:
                remaining = (elapsed / processed) * (total_files - processed)
                elapsed_str = str(timedelta(seconds=int(elapsed)))
                remaining_str = str(timedelta(seconds=int(remaining)))
                print(f"ETA: Прошло {elapsed_str}, Осталось {remaining_str}, Прогресс {processed}/{total_files} ({processed/total_files:.1%})")
|
|
| |
| def combine_chunks(temp_path, final_path): |
| """Объединение обработанных чанков в финальный датасет""" |
| chunks = sorted([ |
| os.path.join(temp_path, d) |
| for d in os.listdir(temp_path) |
| if d.startswith("chunk_") |
| ]) |
| |
| datasets = [load_from_disk(chunk) for chunk in chunks] |
| combined = concatenate_datasets(datasets) |
| combined.save_to_disk(final_path) |
| |
| print(f"✅ Датасет успешно сохранен в: {final_path}") |
|
|
| |
|
|
| |
# --- Pipeline -----------------------------------------------------------------
temp_path = f"{save_path}_temp"
os.makedirs(temp_path, exist_ok=True)

# 1. Collect image/caption pairs and their target bucket sizes.
image_paths, text_paths, width, height = process_folder(folder_path, limit)
print(f"Всего найдено {len(image_paths)} изображений")

# 2. Encode every size bucket and save it under temp_path.
process_in_chunks(image_paths, text_paths, width, height, chunk_size=20000, batch_size=batch_size)

# 3. Build the final dataset BEFORE deleting anything: if combining fails, the
#    source images and the per-chunk results are both still on disk.
combine_chunks(temp_path, save_path)

# 4. Remove the source folder only after the final dataset has been written.
try:
    shutil.rmtree(folder_path)
    print(f"✅ Папка {folder_path} успешно удалена")
except Exception as e:
    print(f"⚠️ Ошибка при удалении папки: {e}")

# 5. Remove the per-chunk intermediate datasets.
try:
    shutil.rmtree(temp_path)
    print(f"✅ Временная папка {temp_path} успешно удалена")
except Exception as e:
    print(f"⚠️ Ошибка при удалении временной папки: {e}")