# pip install flash-attn --no-build-isolation
import torch
import os
import gc
import numpy as np
import random
import json
import shutil
import time
from datasets import Dataset, load_from_disk, concatenate_datasets
from diffusers import AutoencoderKLQwenImage
from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
from typing import Dict, List, Tuple, Optional, Any
from PIL import Image
from tqdm import tqdm
from datetime import timedelta
from accelerate import Accelerator

accelerator = Accelerator()
device = accelerator.device
is_main_process = accelerator.is_main_process
process_index = accelerator.process_index
num_processes = accelerator.num_processes

# ---------------- 1️⃣ Settings ----------------
dtype = torch.float16   # dtype for VAE weights, input batches and latent statistics
batch_size = 5          # images per Dataset.map batch (one VAE forward pass)
min_size = 320          # lower bound for each crop side, px
max_size = 640          # upper bound for each crop side, px
step = 32               # crop sides are floored to a multiple of this before clamping
empty_share = 0.0       # probability of blanking a caption (prob_to_make_empty); 0 disables
limit = 0               # forwarded to process_folder() as `limit`; 0 is meant as "no limit"
folder_path = "/workspace/dataset/d23"
save_path = "/workspace/ds234_640_vae_qwen"
os.makedirs(save_path, exist_ok=True)


def clear_cuda_memory():
    """Free unreferenced Python objects, report peak CUDA usage and drop cached blocks.

    The printed number is ``torch.cuda.max_memory_allocated()``, i.e. the peak
    allocation of this process so far, not the current usage.
    """
    # Collect garbage first so tensors released by the last batch are really
    # freed; only then can empty_cache() hand their blocks back to the driver.
    gc.collect()
    if torch.cuda.is_available():
        used_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"[GPU {process_index}] used_gb: {used_gb:.2f} GB")
        torch.cuda.empty_cache()


# ---------------- 2️⃣ Model loading ----------------
def load_models():
    """Load the Qwen-Image VAE from the local ``vae`` directory, on ``device``, in eval mode."""
    print(f"[GPU {process_index}] Загрузка моделей...")
    vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
    return vae


vae = load_models()

# `or` also maps an explicit None (or 0) in the config to the neutral value.
shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0) or 1.0
mean = getattr(vae.config, "latents_mean", None)
std = getattr(vae.config, "latents_std", None)

# Always bind the tensors so later references never hit a NameError.
latents_mean = None
latents_std = None
if mean is not None and std is not None:
    # Per-channel statistics shaped (1, C, 1, 1, 1) to broadcast over 5-D latents.
    latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1, 1)
    latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1, 1)
# ---------------- 3️⃣ Transforms ----------------
def get_image_transform(min_size=256, max_size=512, step=64):
    """Build a resize-and-crop transform for a given size range.

    The longer side is scaled to ``max_size``. If that leaves either side below
    ``min_size``, the shorter side is scaled to ``min_size`` instead. Each crop
    side is the resized side floored to a multiple of ``step``, then clamped to
    ``[min_size, max_size]``.

    Returns:
        ``transform(img, dry_run=False)``:
          * ``dry_run=True``: ``(crop_width, crop_height)``, computed from
            ``img.size`` only.
          * otherwise: ``(tensor, cropped_pil_image, width, height)``. The
            tensor is normalized with mean 0.5 and std 0.5 per channel.
    """

    def transform(img, dry_run=False):
        original_width, original_height = img.size
        if original_width >= original_height:
            new_width = max_size
            new_height = int(max_size * original_height / original_width)
        else:
            new_height = max_size
            new_width = int(max_size * original_width / original_height)

        if new_height < min_size or new_width < min_size:
            if original_width <= original_height:
                new_width = min_size
                new_height = int(min_size * original_height / original_width)
            else:
                new_height = min_size
                new_width = int(min_size * original_width / original_height)

        crop_width = max(min_size, min(max_size, (new_width // step) * step))
        crop_height = max(min_size, min(max_size, (new_height // step) * step))
        if dry_run:
            return crop_width, crop_height

        img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
        # The vertical crop is biased toward the top third of the image.
        top = (new_height - crop_height) // 3
        # NOTE(review): left=0 keeps the left edge instead of centring; confirm intended.
        left = 0
        img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))
        final_width, final_height = img_cropped.size
        img_tensor = ToTensor()(img_cropped)
        img_tensor = Normalize(mean=[0.5] * 3, std=[0.5] * 3)(img_tensor)
        return img_tensor, img_cropped, final_width, final_height

    return transform


# ---------------- 4️⃣ Processing helpers ----------------
# Caption boilerplate removed by clean_label, applied in this exact order.
_LABEL_NOISE = (
    "Image 1",
    "Image 2",
    "Image 3",
    "Image 4",
    "The image depicts ",
    "The image presents ",
    "The image features ",
    "The image portrays ",
    "The image is ",
)


def clean_label(label):
    """Strip captioner boilerplate ("Image N", "The image depicts ", ...) from ``label``."""
    for fragment in _LABEL_NOISE:
        label = label.replace(fragment, "")
    label = label.strip()
    if label.startswith("."):
        label = label[1:].lstrip()
    return label


def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
    """Randomly blank captions.

    Returns:
        ``(labels_for_model, labels_for_logging)``. A blanked caption is ``""``
        in the first list and ``"zero: <label>"`` in the second.
    """
    labels_for_model = []
    labels_for_logging = []
    for label in original_labels:
        if random.random() < prob_to_make_empty:
            labels_for_model.append("")
            labels_for_logging.append(f"zero: {label}")
        else:
            labels_for_model.append(label)
            labels_for_logging.append(label)
    return labels_for_model, labels_for_logging


def _encode_images(images, texts):
    """Encode a same-size batch of PIL images to normalized VAE latents.

    Returns:
        ``(columns, kept)``. ``kept`` lists the indices into ``images`` that
        survived the transform. ``columns`` is the output dict whose lists are
        all aligned with ``kept``, or ``None`` if no image could be transformed.
    """
    transform = get_image_transform(min_size, max_size, step)
    tensors, widths, heights, kept = [], [], [], []
    for idx, img in enumerate(images):
        try:
            t_img, _, w, h = transform(img)
        except Exception as e:  # best effort: skip this image, keep the batch
            print(f"Ошибка трансформации: {e}")
            continue
        tensors.append(t_img)
        widths.append(w)
        heights.append(h)
        kept.append(idx)

    if not tensors:
        return None, kept

    batch_tensor = torch.stack(tensors).to(device, dtype)
    if batch_tensor.ndim == 4:
        # Insert a singleton axis at dim 2 (presumably the VAE's frame axis).
        batch_tensor = batch_tensor.unsqueeze(2)
    with torch.no_grad():
        latents = vae.encode(batch_tensor).latent_dist.mode()
        if mean is not None and std is not None:
            latents = (latents - latents_mean) / latents_std
        latents = (latents - shift_factor) / scaling_factor
    latents_np = latents.squeeze(2).cpu().numpy()

    # Captions are taken only for kept images so every column has the same length.
    captions = [clean_label(texts[i]) for i in kept]
    # Persist the labels meant for the model (blanked caption -> ""); the
    # logging variant carries a "zero: " prefix and must not be stored.
    captions, _ = process_labels_for_guidance(captions, empty_share)
    return {"vae": latents_np, "text": captions, "width": widths, "height": heights}, kept


def encode_to_latents(images, texts):
    """Encode ``images`` to VAE latents and pair them with cleaned captions.

    An image that fails the transform is dropped together with its caption, so
    all returned lists have equal length. Returns ``None`` if no image could be
    transformed.
    """
    columns, _ = _encode_images(images, texts)
    return columns


def _encode_batch(batch):
    """Callback for ``Dataset.map``: open, encode and keep every column row-aligned.

    Rows whose image cannot be opened or transformed are removed from all
    columns. The callback may therefore return fewer rows than it received,
    which is why the caller passes ``remove_columns``.
    """
    images, opened = [], []
    for i, path in enumerate(batch["image_paths"]):
        try:
            # Convert inside `with` so the file handle is closed right away.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        except (OSError, Image.DecompressionBombError) as e:
            print(f"[GPU {process_index}] cannot open {path}: {e}")
            continue
        opened.append(i)

    paths = [batch["image_paths"][i] for i in opened]
    raw_texts = [batch["texts"][i] for i in opened]
    columns, kept = _encode_images(images, raw_texts)
    if columns is None:
        columns = {"vae": [], "text": [], "width": [], "height": []}
    # Same column order as before: the input columns first, then the encoded ones.
    return {
        "image_paths": [paths[i] for i in kept],
        "texts": [raw_texts[i] for i in kept],
        **columns,
    }


def _read_text(path):
    """Return the stripped UTF-8 contents of ``path``, or "" if it cannot be read."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read().strip()
    except (OSError, UnicodeDecodeError):
        return ""


# ---------------- 5️⃣ Folder scanning ----------------
def _iter_candidate_files(folder_path):
    """Yield ``(image_path, text_path)`` for each image that has a sibling ``.txt`` caption.

    Directories and files are visited in sorted order, so every process
    enumerates exactly the same list.
    """
    for root, dirs, files in os.walk(folder_path):
        dirs.sort()  # in-place sort makes os.walk descend in sorted order
        for filename in sorted(files):
            if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            image_path = os.path.join(root, filename)
            text_path = os.path.splitext(image_path)[0] + ".txt"
            if os.path.exists(text_path):
                yield image_path, text_path


def process_folder(folder_path, limit=None):
    """Collect captioned images under ``folder_path`` and their target crop sizes.

    Args:
        folder_path: root directory, searched recursively for .jpg/.jpeg/.png
            files that have a caption with the same stem and a ``.txt`` suffix.
        limit: stop after this many images; ``None`` or a value <= 0 means no limit.

    Returns:
        ``(image_paths, text_paths, width, height)`` as parallel lists.
        ``width`` and ``height`` are the crop size the transform will produce.
    """
    image_paths, text_paths, width, height = [], [], [], []
    transform = get_image_transform(min_size, max_size, step)
    for image_path, text_path in _iter_candidate_files(folder_path):
        try:
            with Image.open(image_path) as img:
                w, h = transform(img, dry_run=True)
        except (OSError, ZeroDivisionError, Image.DecompressionBombError):
            # Unreadable or truncated file, or a degenerate (0-px) image: skip it.
            continue
        image_paths.append(image_path)
        text_paths.append(text_path)
        width.append(w)
        height.append(h)
        if limit and limit > 0 and len(image_paths) >= limit:
            break
    print(f"Найдено {len(image_paths)} изображений")
    return image_paths, text_paths, width, height


def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=1):
    """Encode the given files chunk by chunk and save one dataset per (chunk, size) group.

    Each group is saved to ``{save_path}_temp/chunk_<i>_<W>x<H>_proc_<rank>_``.
    Groups in which no image could be encoded are not saved.
    """
    total_files = len(image_paths)
    for chunk_idx, start in enumerate(range(0, total_files, chunk_size), 1):
        end = min(start + chunk_size, total_files)

        # Group rows by crop size so every VAE batch can be stacked into one tensor.
        size_groups = {}
        for i in range(start, end):
            group = size_groups.setdefault((width[i], height[i]), {"image_paths": [], "texts": []})
            group["image_paths"].append(image_paths[i])
            group["texts"].append(_read_text(text_paths[i]))

        for (group_w, group_h), group_data in size_groups.items():
            group_dataset = Dataset.from_dict(group_data)
            processed_group = group_dataset.map(
                _encode_batch,
                batched=True,
                batch_size=batch_size,
                # The callback rebuilds every column, which lets it drop failed rows.
                remove_columns=group_dataset.column_names,
            )
            clear_cuda_memory()
            if len(processed_group) == 0:
                print(f"[GPU {process_index}] chunk {chunk_idx} {group_w}x{group_h}: nothing encoded, skipped")
                continue
            # The process rank in the name keeps output paths unique per process.
            group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_{group_w}x{group_h}_proc_{process_index}_"
            processed_group.save_to_disk(group_save_path)


# ---------------- 7️⃣ Merging ----------------
def combine_chunks(temp_path, final_path):
    """Concatenate every ``chunk_*`` dataset in ``temp_path`` and save the result to ``final_path``.

    Does nothing (and prints a notice) when ``temp_path`` holds no chunks.
    """
    chunks = []
    if os.path.isdir(temp_path):
        chunks = sorted(os.path.join(temp_path, d) for d in os.listdir(temp_path) if "chunk_" in d)
    if not chunks:
        print(f"No chunks found in {temp_path}; nothing to combine")
        return
    combined = concatenate_datasets([load_from_disk(c) for c in chunks])
    combined.save_to_disk(final_path)
    print("✅ Сохранено")


# ---------------- MAIN ----------------
def main():
    """Scan, shard across processes, encode, then merge the results on the main process."""
    temp_path = f"{save_path}_temp"
    if is_main_process:
        # Remove leftovers of an interrupted run so they are not merged into the output.
        shutil.rmtree(temp_path, ignore_errors=True)
        os.makedirs(temp_path, exist_ok=True)
    accelerator.wait_for_everyone()

    image_paths, text_paths, width, height = process_folder(folder_path, limit)

    # Sort by size; the path breaks ties so every process builds the same order.
    records = sorted(zip(width, height, image_paths, text_paths))
    # Round-robin shard across processes.
    shard = records[process_index::num_processes]
    width = [r[0] for r in shard]
    height = [r[1] for r in shard]
    image_paths = [r[2] for r in shard]
    text_paths = [r[3] for r in shard]

    print(f"[GPU {process_index}] обрабатывает {len(image_paths)} файлов")
    process_in_chunks(image_paths, text_paths, width, height, chunk_size=5000, batch_size=batch_size)
    accelerator.wait_for_everyone()

    if is_main_process:
        combine_chunks(temp_path, save_path)
        shutil.rmtree(temp_path, ignore_errors=True)


if __name__ == "__main__":
    main()