| | |
| | import torch |
| | import os |
| | import gc |
| | import numpy as np |
| | import random |
| | import json |
| | import shutil |
| | import time |
| |
|
| | from datasets import Dataset, load_from_disk, concatenate_datasets |
| | from diffusers import AutoencoderKL,AutoencoderKLWan,AsymmetricAutoencoderKL,AutoencoderKLFlux2 |
| | from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda |
| | from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM |
| | from typing import Dict, List, Tuple, Optional, Any |
| | from PIL import Image |
| | from tqdm import tqdm |
| | from datetime import timedelta |
| |
|
| | |
# ---- Global configuration ----
dtype = torch.float16                    # compute/storage dtype for VAE encode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 5                           # images per VAE encode batch
min_size = 640                           # shortest allowed side after resize (px)
max_size = 1280                          # longest allowed side after resize (px)
step = 64                                # crop dims rounded down to a multiple of this
empty_share = 0.0                        # share of captions blanked for CFG (0.0 = keep all)
limit = 0                                # max images to process; 0 = no limit

folder_path = "/workspace/ds"            # input root: images + sibling .txt captions
save_path = "/workspace/ds1234_flux32"   # output dir for the final encoded dataset
os.makedirs(save_path, exist_ok=True)
| |
|
| | |
def clear_cuda_memory():
    """Report peak CUDA allocation and release cached GPU memory.

    Prints the peak (not current) number of gigabytes ever allocated by
    this process, then frees Python garbage and returns cached CUDA
    blocks to the driver. No-op beyond ``gc.collect()`` on CPU-only hosts.
    """
    # BUG FIX: collect Python garbage BEFORE emptying the CUDA cache —
    # collecting first drops dead tensor references so their blocks can
    # actually be released by empty_cache(); the original order left them
    # in the caching allocator until the next call.
    gc.collect()
    if torch.cuda.is_available():
        # NOTE: max_memory_allocated() is the peak since process start,
        # not the current usage.
        used_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"used_gb: {used_gb:.2f} GB")
        torch.cuda.empty_cache()
| |
|
| | |
def load_models():
    """Load the pretrained VAE onto the target device in eval mode and return it."""
    print("Загрузка моделей...")
    model = AutoencoderKL.from_pretrained("vae", torch_dtype=dtype)
    return model.to(device).eval()
| |
|
# Instantiate the VAE once at module level; all helpers below use it.
vae = load_models()

# Latent post-processing constants read from the VAE config. Different VAE
# families expose different subsets of these, so each one gets a fallback.
shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None:
    shift_factor = 0.0

scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None:
    scaling_factor = 1.0

# Per-channel latent statistics (exposed e.g. by Wan-style VAE configs).
# BUG FIX: previously latents_mean/latents_std were only bound when the
# config provided them, so encode_to_latents() raised NameError for VAEs
# without these stats; they are now always defined (None when absent).
mean = getattr(vae.config, "latents_mean", None)
std = getattr(vae.config, "latents_std", None)
if mean is not None and std is not None:
    latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
    latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)
else:
    latents_mean = None
    latents_std = None
| | |
| | |
def get_image_transform(min_size=256, max_size=512, step=64):
    """Build a closure that resizes and crops an image for VAE encoding.

    The returned ``transform(img, dry_run=False)`` scales the image so its
    long side equals ``max_size`` (or, if that would push the short side
    below ``min_size``, so the short side equals ``min_size``), then crops
    to dimensions rounded down to a multiple of ``step`` and clamped to
    ``[min_size, max_size]``. With ``dry_run=True`` only the target crop
    size ``(width, height)`` is computed — no pixel data is touched.
    Otherwise returns ``(tensor, cropped_pil, width, height)`` where the
    tensor is normalized into [-1, 1].
    """

    def transform(img, dry_run=False):
        src_w, src_h = img.size

        # Fit the long side to max_size, keeping aspect ratio.
        if src_w >= src_h:
            resized_w = max_size
            resized_h = int(max_size * src_h / src_w)
        else:
            resized_h = max_size
            resized_w = int(max_size * src_w / src_h)

        # If that made the short side too small, fit the short side to
        # min_size instead (the long side may then exceed max_size).
        if resized_h < min_size or resized_w < min_size:
            if src_w <= src_h:
                resized_w = min_size
                resized_h = int(min_size * src_h / src_w)
            else:
                resized_h = min_size
                resized_w = int(min_size * src_w / src_h)

        # Crop dims: round down to a multiple of step, clamp into range.
        crop_w = max(min_size, min(max_size, (resized_w // step) * step))
        crop_h = max(min_size, min(max_size, (resized_h // step) * step))

        if dry_run:
            return crop_w, crop_h

        resized = img.convert("RGB").resize((resized_w, resized_h), Image.LANCZOS)

        # Crop anchored at the left edge, one third of the slack down from
        # the top (keeps heads/subjects that sit in the upper portion).
        y0 = (resized_h - crop_h) // 3
        cropped = resized.crop((0, y0, crop_w, y0 + crop_h))
        out_w, out_h = cropped.size

        tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(ToTensor()(cropped))
        return tensor, cropped, out_w, out_h

    return transform
| |
|
| | |
def clean_label(label):
    """Strip caption boilerplate ("Image N", "The image depicts ", ...) and
    a leading dot, returning the cleaned caption text."""
    boilerplate = (
        "Image 1",
        "Image 2",
        "Image 3",
        "Image 4",
        "The image depicts ",
        "The image presents ",
        "The image features ",
        "The image portrays ",
        "The image is ",
    )
    for phrase in boilerplate:
        label = label.replace(phrase, "")
    label = label.strip()
    if label.startswith("."):
        label = label[1:].lstrip()
    return label
| |
|
def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
    """Prepare caption lists for classifier-free guidance.

    With probability ``prob_to_make_empty`` per label:
      - the model-facing label becomes the empty string;
      - the logging label gets a "zero: " prefix.
    Otherwise both lists carry the original label unchanged.

    Returns:
        (labels_for_model, labels_for_logging) — lists parallel to the input.
    """
    labels_for_model = []
    labels_for_logging = []

    for label in original_labels:
        dropped = random.random() < prob_to_make_empty
        labels_for_model.append("" if dropped else label)
        labels_for_logging.append(f"zero: {label}" if dropped else label)

    return labels_for_model, labels_for_logging
| |
|
| |
|
def encode_to_latents(images, texts):
    """Encode a batch of PIL images into normalized VAE latents.

    Resizes/crops each image, encodes the batch with the module-level
    ``vae``, applies the config-derived latent normalization, and cleans
    the captions. Intended as a batched callable for ``Dataset.map``.

    Args:
        images: PIL images (assumed parallel to ``texts``).
        texts: raw caption strings.

    Returns:
        dict with "vae" (float16 numpy latents), "text", "width", "height"
        — all parallel lists — or None if no image survived the transform.
    """
    transform = get_image_transform(min_size, max_size, step)

    try:
        transformed_tensors = []
        kept_texts = []
        widths, heights = [], []

        # BUG FIX: when a transform fails, skip the caption too — the
        # original kept the full ``texts`` list, so latents and captions
        # went out of sync (and Dataset.map got mismatched column lengths).
        for img, text in zip(images, texts):
            try:
                t_img, _pil_img, w, h = transform(img)
            except Exception as e:
                print(f"Ошибка трансформации: {e}")
                continue
            transformed_tensors.append(t_img)
            kept_texts.append(text)
            widths.append(w)
            heights.append(h)

        if not transformed_tensors:
            return None

        batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
        # NOTE(review): a stacked image batch is 4-D, so this branch looks
        # unreachable; possibly ``ndim == 4`` was intended for a video VAE
        # that wants an extra frame axis — confirm against the VAE in use.
        if batch_tensor.ndim == 5:
            batch_tensor = batch_tensor.unsqueeze(2)

        with torch.no_grad():
            posteriors = vae.encode(batch_tensor).latent_dist.mode()
            if latents_mean is not None and latents_std is not None:
                posteriors = (posteriors - latents_mean) / latents_std
            posteriors = (posteriors - shift_factor) / scaling_factor

        latents_np = posteriors.to(dtype).cpu().numpy()

        text_labels = [clean_label(text) for text in kept_texts]

        # NOTE(review): the model-facing (possibly blanked) prompts are
        # discarded; only the logging labels ("zero: ..."-prefixed) are
        # stored. Kept as-is to preserve behavior — verify downstream code
        # keys off the "zero: " prefix.
        model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)

        return {
            "vae": latents_np,
            "text": text_labels,
            "width": widths,
            "height": heights
        }

    except Exception as e:
        print(f"Критическая ошибка в encode_to_latents: {e}")
        raise
| |
|
| | |
def process_folder(folder_path, limit=None):
    """Recursively scan ``folder_path`` for image/caption pairs.

    For every .jpg/.jpeg/.png that opens successfully and has a sibling
    .txt caption, records its path, the caption path and the crop size the
    transform would produce (dry run, no pixel decode).

    NOTE(review): images that fail to open are DELETED from disk together
    with their caption — this scan is intentionally destructive; confirm
    that is desired before running on irreplaceable data.

    Args:
        folder_path: root directory to walk.
        limit: stop after accepting this many images (falsy or <= 0 = no limit).

    Returns:
        (image_paths, text_paths, widths, heights) — parallel lists.
    """
    image_paths = []
    text_paths = []
    width = []
    height = []
    transform = get_image_transform(min_size, max_size, step)

    for root, dirs, files in os.walk(folder_path):
        for filename in files:
            if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            image_path = os.path.join(root, filename)
            try:
                img = Image.open(image_path)
            except Exception as e:
                # Corrupt/unreadable image: remove it and its caption.
                print(f"Ошибка при открытии {image_path}: {e}")
                os.remove(image_path)
                text_path = os.path.splitext(image_path)[0] + ".txt"
                if os.path.exists(text_path):
                    os.remove(text_path)
                continue

            # BUG FIX: close the handle — the original leaked one open file
            # descriptor per scanned image. The dry run only needs the
            # header-derived size, so closing right after is safe.
            try:
                w, h = transform(img, dry_run=True)
            finally:
                img.close()

            text_path = os.path.splitext(image_path)[0] + ".txt"

            # min(w, h) > 0 always holds after the transform; kept as a cheap
            # sanity guard.
            if os.path.exists(text_path) and min(w, h) > 0:
                image_paths.append(image_path)
                text_paths.append(text_path)
                width.append(w)
                height.append(h)

            if limit and limit > 0 and len(image_paths) >= limit:
                print(f"Достигнут лимит в {limit} изображений")
                return image_paths, text_paths, width, height

    print(f"Найдено {len(image_paths)} изображений с текстовыми описаниями")
    return image_paths, text_paths, width, height
| | |
def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=1):
    """Encode images to latents chunk by chunk, bucketed by crop size.

    Files are processed in slices of ``chunk_size``. Inside each chunk the
    images are grouped by their (width, height) crop so every VAE batch
    stacks same-shaped tensors, and each group is saved as its own dataset
    under ``{save_path}_temp`` for combine_chunks() to merge later.

    Args:
        image_paths: paths of the images to encode.
        text_paths: caption .txt paths, parallel to ``image_paths``.
        width, height: per-image crop sizes (lists parallel to the paths),
            or scalars applied to every image.
        chunk_size: number of files per chunk.
        batch_size: batch size passed to Dataset.map (and thus the VAE).
    """
    total_files = len(image_paths)
    start_time = time.time()

    for chunk_idx, start in enumerate(range(0, total_files, chunk_size), 1):
        end = min(start + chunk_size, total_files)
        chunk_image_paths = image_paths[start:end]
        chunk_text_paths = text_paths[start:end]
        chunk_widths = width[start:end] if isinstance(width, list) else [width] * len(chunk_image_paths)
        chunk_heights = height[start:end] if isinstance(height, list) else [height] * len(chunk_image_paths)

        # Read the caption for every image; unreadable files degrade to "".
        chunk_texts = []
        for text_path in chunk_text_paths:
            try:
                with open(text_path, 'r', encoding='utf-8') as f:
                    chunk_texts.append(f.read().strip())
            except Exception as e:
                print(f"Ошибка чтения {text_path}: {e}")
                chunk_texts.append("")

        # Bucket by crop size so each VAE batch holds same-sized tensors.
        size_groups = {}
        for i in range(len(chunk_image_paths)):
            size_key = (chunk_widths[i], chunk_heights[i])
            if size_key not in size_groups:
                size_groups[size_key] = {"image_paths": [], "texts": []}
            size_groups[size_key]["image_paths"].append(chunk_image_paths[i])
            size_groups[size_key]["texts"].append(chunk_texts[i])

        # Running count of images finished in this chunk. Replaces the old
        # list(size_groups.values()).index(group_data) lookup, which was an
        # O(n) dict-equality scan per group just to compute the ETA.
        done_in_chunk = 0
        for size_key, group_data in size_groups.items():
            print(f"Обработка группы с размером {size_key[0]}x{size_key[1]} - {len(group_data['image_paths'])} изображений")

            group_dataset = Dataset.from_dict({
                "image_path": group_data["image_paths"],
                "text": group_data["texts"]
            })

            processed_group = group_dataset.map(
                lambda examples: encode_to_latents(
                    [Image.open(path) for path in examples["image_path"]],
                    examples["text"]
                ),
                batched=True,
                batch_size=batch_size,
                desc=f"Обработка группы размера {size_key[0]}x{size_key[1]}"
            )

            group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}"
            processed_group.save_to_disk(group_save_path)
            clear_cuda_memory()

            done_in_chunk += len(group_data["image_paths"])
            processed = start + done_in_chunk  # start == (chunk_idx - 1) * chunk_size
            elapsed = time.time() - start_time
            if processed > 0:
                remaining = (elapsed / processed) * (total_files - processed)
                elapsed_str = str(timedelta(seconds=int(elapsed)))
                remaining_str = str(timedelta(seconds=int(remaining)))
                print(f"ETA: Прошло {elapsed_str}, Осталось {remaining_str}, Прогресс {processed}/{total_files} ({processed/total_files:.1%})")
| |
|
| | |
def combine_chunks(temp_path, final_path):
    """Merge every chunk_* dataset under ``temp_path`` into one dataset at ``final_path``."""
    chunk_dirs = sorted(
        os.path.join(temp_path, entry)
        for entry in os.listdir(temp_path)
        if entry.startswith("chunk_")
    )

    loaded = [load_from_disk(chunk_dir) for chunk_dir in chunk_dirs]
    merged = concatenate_datasets(loaded)
    merged.save_to_disk(final_path)

    print(f"✅ Датасет успешно сохранен в: {final_path}")
| |
|
| | |
| |
|
| | |
# ---- Pipeline driver (module level) ----

# Scratch directory that process_in_chunks() writes per-group datasets into.
temp_path = f"{save_path}_temp"
os.makedirs(temp_path, exist_ok=True)

# 1) Scan the source tree for image/caption pairs (limit == 0 -> unlimited).
image_paths, text_paths, width, height = process_folder(folder_path,limit)
print(f"Всего найдено {len(image_paths)} изображений")

# 2) Encode everything to VAE latents, saving intermediate chunks to disk.
process_in_chunks(image_paths, text_paths, width, height, chunk_size=20000, batch_size=batch_size)

# 3) Delete the raw source images.
# NOTE(review): the source folder is removed BEFORE combine_chunks() has
# succeeded — if the merge below fails, the originals are already gone.
# Consider deleting only after the final dataset is saved.
try:
    shutil.rmtree(folder_path)
    print(f"✅ Папка {folder_path} успешно удалена")
except Exception as e:
    print(f"⚠️ Ошибка при удалении папки: {e}")

# 4) Merge all chunk datasets into the final dataset at save_path.
combine_chunks(temp_path, save_path)

# 5) Remove the scratch directory.
try:
    shutil.rmtree(temp_path)
    print(f"✅ Временная папка {temp_path} успешно удалена")
except Exception as e:
    print(f"⚠️ Ошибка при удалении временной папки: {e}")