2b / dataset.py
babkasotona's picture
Upload folder using huggingface_hub
58bb2b7 verified
# pip install flash-attn --no-build-isolation
import torch
import os
import gc
import numpy as np
import random
import json
import shutil
import time
from datasets import Dataset, load_from_disk, concatenate_datasets
from diffusers import AutoencoderKLQwenImage
from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
from typing import Dict, List, Tuple, Optional, Any
from PIL import Image
from tqdm import tqdm
from datetime import timedelta
from accelerate import Accelerator
# Per-process Accelerate context. The values below are read once at import
# time and used as module-level globals by the rest of the script.
accelerator = Accelerator()
device = accelerator.device  # device the VAE and image batches are moved to
is_main_process = accelerator.is_main_process  # gates the final merge step
process_index = accelerator.process_index  # used for sharding, log prefixes and chunk names
num_processes = accelerator.num_processes  # stride used when sharding the file list
# ---------------- 1️⃣ Settings ----------------
dtype = torch.float16  # dtype of the VAE weights and of the image batch fed to it
batch_size = 5  # batch size handed to process_in_chunks (and from there to Dataset.map)
min_size = 320  # lower bound for each side of the crop
max_size = 640  # upper bound for each side of the crop
step = 32  # crop sides are floored to a multiple of this value
empty_share = 0.0  # probability passed to process_labels_for_guidance
limit = 0  # forwarded to process_folder as its `limit` argument
folder_path = "/workspace/dataset/d23"  # root directory walked for image/caption pairs
save_path = "/workspace/ds234_640_vae_qwen"  # final dataset directory; temp chunks go to f"{save_path}_temp"
os.makedirs(save_path, exist_ok=True)
def clear_cuda_memory():
    """Report peak CUDA memory for this process and release cached blocks.

    Garbage collection runs first so that tensors kept alive only by
    reference cycles are freed *before* the caching allocator is asked to
    give its unused blocks back; in the opposite order that memory would
    stay cached until the next call.

    Does nothing besides the garbage collection when CUDA is unavailable.
    """
    gc.collect()
    if torch.cuda.is_available():
        # max_memory_allocated() is the peak since process start (or the last
        # reset), not the current usage.
        used_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"[GPU {process_index}] used_gb: {used_gb:.2f} GB")
        torch.cuda.empty_cache()
# ---------------- 2️⃣ Model loading ----------------
def load_models():
    """Load the VAE used for latent encoding.

    The model is loaded from the ``"vae"`` path/identifier with the
    module-level ``dtype``, moved to this process's ``device`` and switched
    to eval mode.

    Returns:
        The ready-to-use ``AutoencoderKLQwenImage`` instance.
    """
    print(f"[GPU {process_index}] Загрузка моделей...")
    model = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype)
    model = model.to(device)
    return model.eval()
vae = load_models()

# Latent normalisation parameters, taken from the VAE config when present.
# The trailing `or` maps a missing attribute as well as an explicit
# None / 0 onto the neutral value.
shift_factor = getattr(vae.config, "shift_factor", None) or 0.0
scaling_factor = getattr(vae.config, "scaling_factor", None) or 1.0
mean = getattr(vae.config, "latents_mean", None)
std = getattr(vae.config, "latents_std", None)
if not (mean is None or std is None):
    # Shaped (1, N, 1, 1, 1) so the statistics broadcast along dim 1 of
    # 5-D latents. Only defined when both config entries exist; users of
    # these names must repeat the same check.
    latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1, 1)
    latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1, 1)
# ---------------- 3️⃣ Transforms ----------------
def get_image_transform(min_size=256, max_size=512, step=64):
    """Build a resize-and-crop transform bounded by ``min_size``/``max_size``.

    The returned callable ``transform(img, dry_run=False)``:

    * resizes so the longer side equals ``max_size``; if that leaves the
      shorter side below ``min_size`` it instead resizes so the shorter side
      equals ``min_size``;
    * picks a crop whose sides are floored to a multiple of ``step`` and
      clamped to ``[min_size, max_size]``;
    * with ``dry_run=True`` returns only ``(crop_width, crop_height)`` and
      touches nothing but ``img.size``;
    * otherwise returns ``(tensor, cropped_image, width, height)`` where the
      tensor is normalised with mean/std 0.5 per channel.

    The crop starts at the left edge and one third of the vertical slack
    from the top.
    """
    def _scaled(target, numerator, denominator):
        # Proportional length of the other side, truncated towards zero.
        return int(target * numerator / denominator)

    def _snap(side):
        # Floor to the step grid, then clamp into [min_size, max_size].
        return max(min_size, min(max_size, (side // step) * step))

    def transform(img, dry_run=False):
        src_w, src_h = img.size

        # First attempt: longer side becomes max_size.
        if src_w >= src_h:
            new_width, new_height = max_size, _scaled(max_size, src_h, src_w)
        else:
            new_width, new_height = _scaled(max_size, src_w, src_h), max_size

        # Too elongated: anchor the shorter side at min_size instead.
        if min(new_width, new_height) < min_size:
            if src_w <= src_h:
                new_width, new_height = min_size, _scaled(min_size, src_h, src_w)
            else:
                new_width, new_height = _scaled(min_size, src_w, src_h), min_size

        crop_width, crop_height = _snap(new_width), _snap(new_height)
        if dry_run:
            return crop_width, crop_height

        resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
        top = (new_height - crop_height) // 3
        img_cropped = resized.crop((0, top, crop_width, top + crop_height))
        final_width, final_height = img_cropped.size
        img_tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(ToTensor()(img_cropped))
        return img_tensor, img_cropped, final_width, final_height

    return transform
# ---------------- 4️⃣ Processing functions ----------------
def clean_label(label):
    """Strip captioner boilerplate from a caption.

    Removes the markers "Image 1".."Image 4" and the lead-ins
    "The image depicts/presents/features/portrays/is ", trims surrounding
    whitespace, and drops a single leading "." (plus the whitespace after it).

    Args:
        label: Raw caption text.

    Returns:
        The cleaned caption.
    """
    boilerplate = (
        "Image 1", "Image 2", "Image 3", "Image 4",
        "The image depicts ", "The image presents ",
        "The image features ", "The image portrays ", "The image is ",
    )
    # Applied one after another, in this order.
    for phrase in boilerplate:
        label = label.replace(phrase, "")
    label = label.strip()
    return label[1:].lstrip() if label.startswith(".") else label
def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
    """Randomly blank captions, keeping a log-friendly copy.

    One ``random.random()`` draw is made per label, in order. When the draw
    is below ``prob_to_make_empty`` the model-side label becomes ``""`` and
    the logging-side label becomes ``"zero: <label>"``; otherwise both sides
    carry the label unchanged.

    Args:
        original_labels: Iterable of caption strings.
        prob_to_make_empty: Probability of blanking each caption.

    Returns:
        Tuple ``(labels_for_model, labels_for_logging)`` of equal-length lists.
    """
    model_side, log_side = [], []
    for caption in original_labels:
        dropped = random.random() < prob_to_make_empty
        model_side.append("" if dropped else caption)
        log_side.append(f"zero: {caption}" if dropped else caption)
    return model_side, log_side
def encode_to_latents(images, texts):
    """Encode a batch of images to VAE latents and clean their captions.

    Images whose transform raises are skipped together with their caption,
    so the returned columns always have the same length and every caption
    stays paired with its own latent.

    Args:
        images: Image objects accepted by the transform built by
            ``get_image_transform``. All of them must end up with the same
            crop size, otherwise stacking fails.
        texts: Captions, one per image, in the same order.

    Returns:
        ``None`` when no image could be transformed, otherwise a dict with
        ``"vae"`` (numpy array of latents), ``"text"``, ``"width"`` and
        ``"height"``, one entry per successfully transformed image.

    Note:
        When used as a batched ``Dataset.map`` function, a result with fewer
        rows than the input batch is only accepted if the input columns are
        removed by the caller.
    """
    transform = get_image_transform(min_size, max_size, step)
    transformed_tensors = []
    kept_texts = []
    widths, heights = [], []
    for img, text in zip(images, texts):
        try:
            t_img, _, w, h = transform(img)
        except Exception as e:
            print(f"Ошибка трансформации: {e}")
            continue
        transformed_tensors.append(t_img)
        kept_texts.append(text)
        widths.append(w)
        heights.append(h)
    if not transformed_tensors:
        return None

    batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
    if batch_tensor.ndim == 4:
        # Insert a singleton axis at dim 2 to get a 5-D input; presumably the
        # frame axis this VAE expects - TODO confirm against the model docs.
        batch_tensor = batch_tensor.unsqueeze(2)

    with torch.no_grad():
        # mode() takes the distribution's mode instead of sampling from it.
        posteriors = vae.encode(batch_tensor).latent_dist.mode()
        if mean is not None and std is not None:
            posteriors = (posteriors - latents_mean) / latents_std
        # NOTE(review): this divides by scaling_factor; confirm the consumer
        # of these latents uses the matching inverse.
        posteriors = (posteriors - shift_factor) / scaling_factor

    # Drop the singleton axis again before handing the latents to numpy.
    latents_np = posteriors.squeeze(2).cpu().numpy()

    text_labels = [clean_label(text) for text in kept_texts]
    # NOTE(review): the second return value (the logging variant, prefixed
    # with "zero: " for blanked captions) is what gets stored - confirm this
    # is intended rather than the model variant.
    _, text_labels = process_labels_for_guidance(text_labels, empty_share)
    return {
        "vae": latents_np,
        "text": text_labels,
        "width": widths,
        "height": heights,
    }
# ---------------- 5️⃣ Folder scan ----------------
def process_folder(folder_path, limit=None):
    """Collect image/caption pairs under ``folder_path`` with their crop sizes.

    An image (``.jpg``/``.jpeg``/``.png``, case-insensitive) is kept only if
    a ``.txt`` file with the same stem sits next to it and the image can be
    opened. The crop size comes from a dry run of the image transform, so no
    pixel data is processed here.

    Args:
        folder_path: Root directory, walked recursively.
        limit: Maximum number of pairs to collect. ``None`` or ``0`` means
            no limit.

    Returns:
        Four parallel lists: image paths, caption paths, crop widths and
        crop heights.
    """
    image_paths, text_paths, width, height = [], [], [], []
    transform = get_image_transform(min_size, max_size, step)
    limit_reached = False
    for root, _, files in os.walk(folder_path):
        for filename in files:
            if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            image_path = os.path.join(root, filename)
            text_path = os.path.splitext(image_path)[0] + ".txt"
            # Check the caption first: no point opening an image that would
            # be discarded anyway.
            if not os.path.exists(text_path):
                continue
            try:
                img = Image.open(image_path)
            except Exception as e:
                # Unreadable files are skipped on purpose (best effort), but
                # no longer silently and no longer swallowing
                # KeyboardInterrupt.
                print(f"Пропуск {image_path}: {e}")
                continue
            # Close the file handle as soon as the size has been read.
            with img:
                w, h = transform(img, dry_run=True)
            image_paths.append(image_path)
            text_paths.append(text_path)
            width.append(w)
            height.append(h)
            if limit and len(image_paths) >= limit:
                limit_reached = True
                break
        if limit_reached:
            break
    print(f"Найдено {len(image_paths)} изображений")
    return image_paths, text_paths, width, height
def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=1):
    """Encode the files chunk by chunk and save one dataset per crop size.

    Within each chunk the files are grouped by ``(width, height)`` so every
    batch passed to ``encode_to_latents`` contains images of one crop size.
    Each group is saved under ``f"{save_path}_temp"`` with a name made of the
    chunk index, the size and this process's index.

    Args:
        image_paths: Image file paths.
        text_paths: Caption file paths, parallel to ``image_paths``.
        width: Crop widths, parallel to ``image_paths``.
        height: Crop heights, parallel to ``image_paths``.
        chunk_size: Number of files handled per chunk.
        batch_size: Batch size passed to ``Dataset.map``.
    """
    total_files = len(image_paths)
    for chunk_idx, start in enumerate(range(0, total_files, chunk_size), 1):
        end = min(start + chunk_size, total_files)

        chunk_texts = []
        for text_path in text_paths[start:end]:
            try:
                with open(text_path, "r", encoding="utf-8") as f:
                    chunk_texts.append(f.read().strip())
            except (OSError, UnicodeError) as e:
                # Best effort: an unreadable caption becomes an empty string.
                print(f"Не удалось прочитать {text_path}: {e}")
                chunk_texts.append("")

        # Group by crop size; dict insertion order keeps groups in the order
        # their size first appears.
        size_groups = {}
        rows = zip(image_paths[start:end], chunk_texts, width[start:end], height[start:end])
        for image_path, text, w, h in rows:
            group = size_groups.setdefault((w, h), {"image_paths": [], "texts": []})
            group["image_paths"].append(image_path)
            group["texts"].append(text)

        for size_key, group_data in size_groups.items():
            group_dataset = Dataset.from_dict(group_data)
            processed_group = group_dataset.map(
                # Images are only opened here; decoding and any resulting
                # errors are left to encode_to_latents.
                lambda ex: encode_to_latents(
                    [Image.open(p) for p in ex["image_paths"]],
                    ex["texts"],
                ),
                batched=True,
                batch_size=batch_size,
            )
            # The process index keeps names unique across processes.
            group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_{size_key[0]}x{size_key[1]}_proc_{process_index}_"
            processed_group.save_to_disk(group_save_path)
            clear_cuda_memory()
# ---------------- 7️⃣ Merging ----------------
def combine_chunks(temp_path, final_path):
    """Concatenate every saved chunk in ``temp_path`` into one dataset.

    Every entry of ``temp_path`` whose name contains ``"chunk_"`` is loaded
    as a dataset; the chunks are concatenated in sorted path order and the
    result is saved to ``final_path``.

    If there are no chunks, a message is printed and nothing is written,
    instead of failing inside ``concatenate_datasets`` on an empty list.

    Args:
        temp_path: Directory holding the per-chunk datasets.
        final_path: Destination directory for the merged dataset.
    """
    chunks = sorted(
        os.path.join(temp_path, d)
        for d in os.listdir(temp_path)
        if "chunk_" in d
    )
    if not chunks:
        print(f"Не найдено чанков в {temp_path}, нечего объединять")
        return
    datasets = [load_from_disk(c) for c in chunks]
    combined = concatenate_datasets(datasets)
    combined.save_to_disk(final_path)
    print("✅ Сохранено")
# ---------------- MAIN ----------------
# NOTE(review): chunks left in temp_path by an earlier, interrupted run are
# merged as well - confirm the directory is clean before starting.
temp_path = f"{save_path}_temp"
os.makedirs(temp_path, exist_ok=True)

image_paths, text_paths, width, height = process_folder(folder_path, limit)

# Order by crop size (stable, so ties keep discovery order) ...
sorted_indices = sorted(range(len(width)), key=lambda i: (width[i], height[i]))

# --- shard across GPUs ---
# ... then give each process every num_processes-th entry of that order.
indices = sorted_indices[process_index::num_processes]
image_paths = [image_paths[i] for i in indices]
text_paths = [text_paths[i] for i in indices]
width = [width[i] for i in indices]
height = [height[i] for i in indices]

print(f"[GPU {process_index}] обрабатывает {len(image_paths)} файлов")
process_in_chunks(image_paths, text_paths, width, height, chunk_size=5000, batch_size=batch_size)

# All processes must have written their chunks before they are merged.
accelerator.wait_for_everyone()

# Merge and clean up on the main process only.
if is_main_process:
    combine_chunks(temp_path, save_path)
    # Best-effort cleanup; a failure to delete must not fail the run.
    shutil.rmtree(temp_path, ignore_errors=True)