Instructions to use babkasotona/2b with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use babkasotona/2b with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("babkasotona/2b", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
init
Browse files- .gitattributes +5 -7
- .gitignore +0 -8
- dataset-sdxs2b-1152.py +300 -0
- dataset-sdxs2b-640.py +285 -0
- dataset_sample.ipynb +3 -170
- pipeline_sdxs.py +106 -207
- refined.jpg +3 -0
- scheduler/.ipynb_checkpoints/scheduler_config-checkpoint.json +0 -22
- scheduler/scheduler_config.json +3 -22
- test.ipynb +2 -2
- train-sdxs2b.py +33 -10
- transformer/diffusion_pytorch_model.safetensors +2 -2
- wandb/debug-internal.log +0 -0
- wandb/debug-internal.log +1 -0
- wandb/debug.log +0 -19
- wandb/debug.log +1 -0
.gitattributes
CHANGED
|
@@ -33,10 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
wandb/run-20260504_065935-dzvbyo3j/run-dzvbyo3j.wandb filter=lfs diff=lfs merge=lfs -text
|
| 42 |
-
wandb/run-20260505_075313-ti70f47q/run-ti70f47q.wandb filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.ipynb filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
media/refined.webp filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
.gitignore
CHANGED
|
@@ -11,11 +11,3 @@ datasets
|
|
| 11 |
test
|
| 12 |
wandb
|
| 13 |
nohup.out
|
| 14 |
-
samples/
|
| 15 |
-
transformer/
|
| 16 |
-
*.jpg
|
| 17 |
-
*.png
|
| 18 |
-
datasets/
|
| 19 |
-
samples/
|
| 20 |
-
*.jpg
|
| 21 |
-
train.py
|
|
|
|
| 11 |
test
|
| 12 |
wandb
|
| 13 |
nohup.out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset-sdxs2b-1152.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pip install flash-attn --no-build-isolation
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
import gc
|
| 5 |
+
import numpy as np
|
| 6 |
+
import random
|
| 7 |
+
import json
|
| 8 |
+
import shutil
|
| 9 |
+
import time
|
| 10 |
+
|
| 11 |
+
from datasets import Dataset, load_from_disk, concatenate_datasets
|
| 12 |
+
from diffusers import AutoencoderKLQwenImage
|
| 13 |
+
from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
|
| 14 |
+
from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
|
| 15 |
+
from typing import Dict, List, Tuple, Optional, Any
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from datetime import timedelta
|
| 19 |
+
from accelerate import Accelerator
|
| 20 |
+
|
| 21 |
+
accelerator = Accelerator()
|
| 22 |
+
device = accelerator.device
|
| 23 |
+
is_main_process = accelerator.is_main_process
|
| 24 |
+
process_index = accelerator.process_index
|
| 25 |
+
num_processes = accelerator.num_processes
|
| 26 |
+
|
| 27 |
+
# ---------------- 1️⃣ Настройки ----------------
|
| 28 |
+
dtype = torch.float16
|
| 29 |
+
batch_size = 5
|
| 30 |
+
min_size = 576
|
| 31 |
+
max_size = 1152
|
| 32 |
+
step = 64
|
| 33 |
+
empty_share = 0.0
|
| 34 |
+
limit = 0
|
| 35 |
+
|
| 36 |
+
folder_path = "/root/dataset"
|
| 37 |
+
save_path = "/root/ds1234_1152_vae_qwen"
|
| 38 |
+
os.makedirs(save_path, exist_ok=True)
|
| 39 |
+
|
| 40 |
+
def clear_cuda_memory():
|
| 41 |
+
if torch.cuda.is_available():
|
| 42 |
+
used_gb = torch.cuda.max_memory_allocated() / 1024**3
|
| 43 |
+
print(f"[GPU {process_index}] used_gb: {used_gb:.2f} GB")
|
| 44 |
+
torch.cuda.empty_cache()
|
| 45 |
+
gc.collect()
|
| 46 |
+
|
| 47 |
+
# ---------------- 2️⃣ Загрузка моделей ----------------
|
| 48 |
+
def load_models():
|
| 49 |
+
print(f"[GPU {process_index}] Загрузка моделей...")
|
| 50 |
+
vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
|
| 51 |
+
return vae
|
| 52 |
+
|
| 53 |
+
vae = load_models()
|
| 54 |
+
|
| 55 |
+
shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
|
| 56 |
+
scaling_factor = getattr(vae.config, "scaling_factor", 1.0) or 1.0
|
| 57 |
+
|
| 58 |
+
mean = getattr(vae.config, "latents_mean", None)
|
| 59 |
+
std = getattr(vae.config, "latents_std", None)
|
| 60 |
+
if mean is not None and std is not None:
|
| 61 |
+
latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1, 1)
|
| 62 |
+
latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1, 1)
|
| 63 |
+
|
| 64 |
+
# ---------------- 3️⃣ Трансформации ----------------
|
| 65 |
+
def get_image_transform(min_size=256, max_size=512, step=64):
|
| 66 |
+
def transform(img, dry_run=False):
|
| 67 |
+
original_width, original_height = img.size
|
| 68 |
+
|
| 69 |
+
if original_width >= original_height:
|
| 70 |
+
new_width = max_size
|
| 71 |
+
new_height = int(max_size * original_height / original_width)
|
| 72 |
+
else:
|
| 73 |
+
new_height = max_size
|
| 74 |
+
new_width = int(max_size * original_width / original_height)
|
| 75 |
+
|
| 76 |
+
if new_height < min_size or new_width < min_size:
|
| 77 |
+
if original_width <= original_height:
|
| 78 |
+
new_width = min_size
|
| 79 |
+
new_height = int(min_size * original_height / original_width)
|
| 80 |
+
else:
|
| 81 |
+
new_height = min_size
|
| 82 |
+
new_width = int(min_size * original_width / original_height)
|
| 83 |
+
|
| 84 |
+
crop_width = min(max_size, (new_width // step) * step)
|
| 85 |
+
crop_height = min(max_size, (new_height // step) * step)
|
| 86 |
+
|
| 87 |
+
crop_width = max(min_size, crop_width)
|
| 88 |
+
crop_height = max(min_size, crop_height)
|
| 89 |
+
|
| 90 |
+
if dry_run:
|
| 91 |
+
return crop_width, crop_height
|
| 92 |
+
|
| 93 |
+
img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
|
| 94 |
+
|
| 95 |
+
top = (new_height - crop_height) // 3
|
| 96 |
+
left = 0
|
| 97 |
+
|
| 98 |
+
img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))
|
| 99 |
+
|
| 100 |
+
final_width, final_height = img_cropped.size
|
| 101 |
+
|
| 102 |
+
img_tensor = ToTensor()(img_cropped)
|
| 103 |
+
img_tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(img_tensor)
|
| 104 |
+
return img_tensor, img_cropped, final_width, final_height
|
| 105 |
+
|
| 106 |
+
return transform
|
| 107 |
+
|
| 108 |
+
# ---------------- 4️⃣ Функции обработки ----------------
|
| 109 |
+
def clean_label(label):
|
| 110 |
+
label = label.replace("Image 1","").replace("Image 2","").replace("Image 3","").replace("Image 4","")
|
| 111 |
+
label = label.replace("The image depicts ","").replace("The image presents ","")
|
| 112 |
+
label = label.replace("The image features ","").replace("The image portrays ","").replace("The image is ","").strip()
|
| 113 |
+
if label.startswith("."):
|
| 114 |
+
label = label[1:].lstrip()
|
| 115 |
+
return label
|
| 116 |
+
|
| 117 |
+
def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
|
| 118 |
+
labels_for_model = []
|
| 119 |
+
labels_for_logging = []
|
| 120 |
+
|
| 121 |
+
for label in original_labels:
|
| 122 |
+
if random.random() < prob_to_make_empty:
|
| 123 |
+
labels_for_model.append("")
|
| 124 |
+
labels_for_logging.append(f"zero: {label}")
|
| 125 |
+
else:
|
| 126 |
+
labels_for_model.append(label)
|
| 127 |
+
labels_for_logging.append(label)
|
| 128 |
+
|
| 129 |
+
return labels_for_model, labels_for_logging
|
| 130 |
+
|
| 131 |
+
def encode_to_latents(images, texts):
|
| 132 |
+
transform = get_image_transform(min_size, max_size, step)
|
| 133 |
+
|
| 134 |
+
transformed_tensors = []
|
| 135 |
+
widths, heights = [], []
|
| 136 |
+
|
| 137 |
+
for img in images:
|
| 138 |
+
try:
|
| 139 |
+
t_img, _, w, h = transform(img)
|
| 140 |
+
transformed_tensors.append(t_img)
|
| 141 |
+
widths.append(w)
|
| 142 |
+
heights.append(h)
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(f"Ошибка трансформации: {e}")
|
| 145 |
+
|
| 146 |
+
if not transformed_tensors:
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
|
| 150 |
+
|
| 151 |
+
if batch_tensor.ndim==4:
|
| 152 |
+
batch_tensor = batch_tensor.unsqueeze(2)
|
| 153 |
+
|
| 154 |
+
with torch.no_grad():
|
| 155 |
+
posteriors = vae.encode(batch_tensor).latent_dist.mode()
|
| 156 |
+
if mean is not None and std is not None:
|
| 157 |
+
posteriors = (posteriors - latents_mean) / latents_std
|
| 158 |
+
posteriors = (posteriors - shift_factor) / scaling_factor
|
| 159 |
+
|
| 160 |
+
#latents_np = posteriors.cpu().numpy()
|
| 161 |
+
latents_np = posteriors.squeeze(2).cpu().numpy()
|
| 162 |
+
|
| 163 |
+
text_labels = [clean_label(text) for text in texts]
|
| 164 |
+
_, text_labels = process_labels_for_guidance(text_labels, empty_share)
|
| 165 |
+
|
| 166 |
+
return {
|
| 167 |
+
"vae": latents_np,
|
| 168 |
+
"text": text_labels,
|
| 169 |
+
"width": widths,
|
| 170 |
+
"height": heights
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
# ---------------- 5️⃣ Обработка папки ----------------
|
| 174 |
+
def process_folder(folder_path, limit=None):
|
| 175 |
+
image_paths, text_paths, width, height = [], [], [], []
|
| 176 |
+
transform = get_image_transform(min_size, max_size, step)
|
| 177 |
+
|
| 178 |
+
for root, _, files in os.walk(folder_path):
|
| 179 |
+
for filename in files:
|
| 180 |
+
if filename.lower().endswith((".jpg",".jpeg",".png",".webp")):
|
| 181 |
+
image_path = os.path.join(root, filename)
|
| 182 |
+
try:
|
| 183 |
+
img = Image.open(image_path)
|
| 184 |
+
except:
|
| 185 |
+
continue
|
| 186 |
+
|
| 187 |
+
w,h = transform(img, dry_run=True)
|
| 188 |
+
text_path = os.path.splitext(image_path)[0]+".txt"
|
| 189 |
+
|
| 190 |
+
if os.path.exists(text_path):
|
| 191 |
+
image_paths.append(image_path)
|
| 192 |
+
text_paths.append(text_path)
|
| 193 |
+
width.append(w)
|
| 194 |
+
height.append(h)
|
| 195 |
+
|
| 196 |
+
print(f"Найдено {len(image_paths)} изображений")
|
| 197 |
+
return image_paths, text_paths, width, height
|
| 198 |
+
|
| 199 |
+
def process_in_chunks(image_paths, text_paths, width, height, chunk_size=5000, batch_size=1):
|
| 200 |
+
total_files = len(image_paths)
|
| 201 |
+
start_time = time.time()
|
| 202 |
+
|
| 203 |
+
for chunk_idx, start in enumerate(range(0,total_files,chunk_size),1):
|
| 204 |
+
end = min(start+chunk_size,total_files)
|
| 205 |
+
|
| 206 |
+
chunk_image_paths = image_paths[start:end]
|
| 207 |
+
chunk_text_paths = text_paths[start:end]
|
| 208 |
+
chunk_widths = width[start:end]
|
| 209 |
+
chunk_heights = height[start:end]
|
| 210 |
+
|
| 211 |
+
chunk_texts = []
|
| 212 |
+
for text_path in chunk_text_paths:
|
| 213 |
+
try:
|
| 214 |
+
with open(text_path,'r',encoding='utf-8') as f:
|
| 215 |
+
chunk_texts.append(f.read().strip())
|
| 216 |
+
except:
|
| 217 |
+
chunk_texts.append("")
|
| 218 |
+
|
| 219 |
+
size_groups = {}
|
| 220 |
+
for i in range(len(chunk_image_paths)):
|
| 221 |
+
key=(chunk_widths[i],chunk_heights[i])
|
| 222 |
+
size_groups.setdefault(key,{"image_paths":[],"texts":[]})
|
| 223 |
+
size_groups[key]["image_paths"].append(chunk_image_paths[i])
|
| 224 |
+
size_groups[key]["texts"].append(chunk_texts[i])
|
| 225 |
+
|
| 226 |
+
for size_key,group_data in size_groups.items():
|
| 227 |
+
group_dataset = Dataset.from_dict(group_data)
|
| 228 |
+
|
| 229 |
+
processed_group = group_dataset.map(
|
| 230 |
+
lambda ex: encode_to_latents(
|
| 231 |
+
[Image.open(p) for p in ex["image_paths"]],
|
| 232 |
+
#[Image.open(p).convert("RGB") for p in ex["image_paths"]], # <--- Добавил .convert("RGB"), чтобы картинка загрузилась в память
|
| 233 |
+
ex["texts"]
|
| 234 |
+
),
|
| 235 |
+
batched=True,
|
| 236 |
+
batch_size=batch_size,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# --- NEW: уникальный путь ---
|
| 240 |
+
group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_{size_key[0]}x{size_key[1]}_proc_{process_index}_"
|
| 241 |
+
# --- END NEW ---
|
| 242 |
+
|
| 243 |
+
processed_group.save_to_disk(group_save_path)
|
| 244 |
+
clear_cuda_memory()
|
| 245 |
+
|
| 246 |
+
# ---------------- 7️⃣ Объединение ----------------
|
| 247 |
+
def combine_chunks(temp_path, final_path):
|
| 248 |
+
chunks = sorted([
|
| 249 |
+
os.path.join(temp_path,d)
|
| 250 |
+
for d in os.listdir(temp_path)
|
| 251 |
+
if "chunk_" in d
|
| 252 |
+
])
|
| 253 |
+
|
| 254 |
+
datasets = [load_from_disk(c) for c in chunks]
|
| 255 |
+
combined = concatenate_datasets(datasets)
|
| 256 |
+
combined.save_to_disk(final_path)
|
| 257 |
+
|
| 258 |
+
print("✅ Сохранено")
|
| 259 |
+
|
| 260 |
+
# ---------------- MAIN ----------------
|
| 261 |
+
temp_path = f"{save_path}_temp"
|
| 262 |
+
os.makedirs(temp_path, exist_ok=True)
|
| 263 |
+
|
| 264 |
+
image_paths, text_paths, width, height = process_folder(folder_path,limit)
|
| 265 |
+
|
| 266 |
+
# сортировка
|
| 267 |
+
sorted_indices = sorted(range(len(width)), key=lambda i:(width[i],height[i]))
|
| 268 |
+
image_paths = [image_paths[i] for i in sorted_indices]
|
| 269 |
+
text_paths = [text_paths[i] for i in sorted_indices]
|
| 270 |
+
width = [width[i] for i in sorted_indices]
|
| 271 |
+
height = [height[i] for i in sorted_indices]
|
| 272 |
+
|
| 273 |
+
# --- shard по GPU ---
|
| 274 |
+
indices = list(range(len(image_paths)))
|
| 275 |
+
indices = indices[process_index::num_processes]
|
| 276 |
+
|
| 277 |
+
image_paths = [image_paths[i] for i in indices]
|
| 278 |
+
text_paths = [text_paths[i] for i in indices]
|
| 279 |
+
width = [width[i] for i in indices]
|
| 280 |
+
height = [height[i] for i in indices]
|
| 281 |
+
|
| 282 |
+
print(f"[GPU {process_index}] обрабатывает {len(image_paths)} файлов")
|
| 283 |
+
|
| 284 |
+
process_in_chunks(image_paths, text_paths, width, height, chunk_size=1000, batch_size=batch_size)
|
| 285 |
+
|
| 286 |
+
accelerator.wait_for_everyone()
|
| 287 |
+
|
| 288 |
+
# --- NEW: только главный процесс ---
|
| 289 |
+
if is_main_process:
|
| 290 |
+
try:
|
| 291 |
+
shutil.rmtree(folder_path)
|
| 292 |
+
except:
|
| 293 |
+
pass
|
| 294 |
+
|
| 295 |
+
combine_chunks(temp_path, save_path)
|
| 296 |
+
|
| 297 |
+
try:
|
| 298 |
+
shutil.rmtree(temp_path)
|
| 299 |
+
except:
|
| 300 |
+
pass
|
dataset-sdxs2b-640.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pip install flash-attn --no-build-isolation
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
import gc
|
| 5 |
+
import numpy as np
|
| 6 |
+
import random
|
| 7 |
+
import json
|
| 8 |
+
import shutil
|
| 9 |
+
import time
|
| 10 |
+
|
| 11 |
+
from datasets import Dataset, load_from_disk, concatenate_datasets
|
| 12 |
+
from diffusers import AutoencoderKLQwenImage
|
| 13 |
+
from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
|
| 14 |
+
from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
|
| 15 |
+
from typing import Dict, List, Tuple, Optional, Any
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from datetime import timedelta
|
| 19 |
+
from accelerate import Accelerator
|
| 20 |
+
|
| 21 |
+
accelerator = Accelerator()
|
| 22 |
+
device = accelerator.device
|
| 23 |
+
is_main_process = accelerator.is_main_process
|
| 24 |
+
process_index = accelerator.process_index
|
| 25 |
+
num_processes = accelerator.num_processes
|
| 26 |
+
|
| 27 |
+
# ---------------- 1️⃣ Настройки ----------------
|
| 28 |
+
dtype = torch.float16
|
| 29 |
+
batch_size = 5
|
| 30 |
+
min_size = 320
|
| 31 |
+
max_size = 640
|
| 32 |
+
step = 64
|
| 33 |
+
empty_share = 0.0
|
| 34 |
+
limit = 0
|
| 35 |
+
|
| 36 |
+
folder_path = "/root/datasets/butterfly"
|
| 37 |
+
save_path = "datasets/dsb_640_vae_qwen"
|
| 38 |
+
os.makedirs(save_path, exist_ok=True)
|
| 39 |
+
|
| 40 |
+
def clear_cuda_memory():
|
| 41 |
+
if torch.cuda.is_available():
|
| 42 |
+
used_gb = torch.cuda.max_memory_allocated() / 1024**3
|
| 43 |
+
print(f"[GPU {process_index}] used_gb: {used_gb:.2f} GB")
|
| 44 |
+
torch.cuda.empty_cache()
|
| 45 |
+
gc.collect()
|
| 46 |
+
|
| 47 |
+
# ---------------- 2️⃣ Загрузка моделей ----------------
|
| 48 |
+
def load_models():
|
| 49 |
+
print(f"[GPU {process_index}] Загрузка моделей...")
|
| 50 |
+
vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
|
| 51 |
+
return vae
|
| 52 |
+
|
| 53 |
+
vae = load_models()
|
| 54 |
+
|
| 55 |
+
shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
|
| 56 |
+
scaling_factor = getattr(vae.config, "scaling_factor", 1.0) or 1.0
|
| 57 |
+
|
| 58 |
+
mean = getattr(vae.config, "latents_mean", None)
|
| 59 |
+
std = getattr(vae.config, "latents_std", None)
|
| 60 |
+
if mean is not None and std is not None:
|
| 61 |
+
latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1, 1)
|
| 62 |
+
latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1, 1)
|
| 63 |
+
|
| 64 |
+
# ---------------- 3️⃣ Трансформации ----------------
|
| 65 |
+
def get_image_transform(min_size=256, max_size=512, step=64):
|
| 66 |
+
def transform(img, dry_run=False):
|
| 67 |
+
original_width, original_height = img.size
|
| 68 |
+
|
| 69 |
+
if original_width >= original_height:
|
| 70 |
+
new_width = max_size
|
| 71 |
+
new_height = int(max_size * original_height / original_width)
|
| 72 |
+
else:
|
| 73 |
+
new_height = max_size
|
| 74 |
+
new_width = int(max_size * original_width / original_height)
|
| 75 |
+
|
| 76 |
+
if new_height < min_size or new_width < min_size:
|
| 77 |
+
if original_width <= original_height:
|
| 78 |
+
new_width = min_size
|
| 79 |
+
new_height = int(min_size * original_height / original_width)
|
| 80 |
+
else:
|
| 81 |
+
new_height = min_size
|
| 82 |
+
new_width = int(min_size * original_width / original_height)
|
| 83 |
+
|
| 84 |
+
crop_width = min(max_size, (new_width // step) * step)
|
| 85 |
+
crop_height = min(max_size, (new_height // step) * step)
|
| 86 |
+
|
| 87 |
+
crop_width = max(min_size, crop_width)
|
| 88 |
+
crop_height = max(min_size, crop_height)
|
| 89 |
+
|
| 90 |
+
if dry_run:
|
| 91 |
+
return crop_width, crop_height
|
| 92 |
+
|
| 93 |
+
img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
|
| 94 |
+
|
| 95 |
+
top = (new_height - crop_height) // 3
|
| 96 |
+
left = 0
|
| 97 |
+
|
| 98 |
+
img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))
|
| 99 |
+
|
| 100 |
+
final_width, final_height = img_cropped.size
|
| 101 |
+
|
| 102 |
+
img_tensor = ToTensor()(img_cropped)
|
| 103 |
+
img_tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(img_tensor)
|
| 104 |
+
return img_tensor, img_cropped, final_width, final_height
|
| 105 |
+
|
| 106 |
+
return transform
|
| 107 |
+
|
| 108 |
+
# ---------------- 4️⃣ Функции обработки ----------------
|
| 109 |
+
def clean_label(label):
|
| 110 |
+
label = label.replace("Image 1","").replace("Image 2","").replace("Image 3","").replace("Image 4","")
|
| 111 |
+
label = label.replace("The image depicts ","").replace("The image presents ","")
|
| 112 |
+
label = label.replace("The image features ","").replace("The image portrays ","").replace("The image is ","").strip()
|
| 113 |
+
if label.startswith("."):
|
| 114 |
+
label = label[1:].lstrip()
|
| 115 |
+
return label
|
| 116 |
+
|
| 117 |
+
def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
|
| 118 |
+
labels_for_model = []
|
| 119 |
+
labels_for_logging = []
|
| 120 |
+
|
| 121 |
+
for label in original_labels:
|
| 122 |
+
if random.random() < prob_to_make_empty:
|
| 123 |
+
labels_for_model.append("")
|
| 124 |
+
labels_for_logging.append(f"zero: {label}")
|
| 125 |
+
else:
|
| 126 |
+
labels_for_model.append(label)
|
| 127 |
+
labels_for_logging.append(label)
|
| 128 |
+
|
| 129 |
+
return labels_for_model, labels_for_logging
|
| 130 |
+
|
| 131 |
+
def encode_to_latents(images, texts):
|
| 132 |
+
transform = get_image_transform(min_size, max_size, step)
|
| 133 |
+
|
| 134 |
+
transformed_tensors = []
|
| 135 |
+
widths, heights = [], []
|
| 136 |
+
|
| 137 |
+
for img in images:
|
| 138 |
+
try:
|
| 139 |
+
t_img, _, w, h = transform(img)
|
| 140 |
+
transformed_tensors.append(t_img)
|
| 141 |
+
widths.append(w)
|
| 142 |
+
heights.append(h)
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(f"Ошибка трансформации: {e}")
|
| 145 |
+
|
| 146 |
+
if not transformed_tensors:
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
|
| 150 |
+
|
| 151 |
+
if batch_tensor.ndim==4:
|
| 152 |
+
batch_tensor = batch_tensor.unsqueeze(2)
|
| 153 |
+
|
| 154 |
+
with torch.no_grad():
|
| 155 |
+
posteriors = vae.encode(batch_tensor).latent_dist.mode()
|
| 156 |
+
if mean is not None and std is not None:
|
| 157 |
+
posteriors = (posteriors - latents_mean) / latents_std
|
| 158 |
+
posteriors = (posteriors - shift_factor) / scaling_factor
|
| 159 |
+
|
| 160 |
+
#latents_np = posteriors.cpu().numpy()
|
| 161 |
+
latents_np = posteriors.squeeze(2).cpu().numpy()
|
| 162 |
+
|
| 163 |
+
text_labels = [clean_label(text) for text in texts]
|
| 164 |
+
_, text_labels = process_labels_for_guidance(text_labels, empty_share)
|
| 165 |
+
|
| 166 |
+
return {
|
| 167 |
+
"vae": latents_np,
|
| 168 |
+
"text": text_labels,
|
| 169 |
+
"width": widths,
|
| 170 |
+
"height": heights
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
# ---------------- 5️⃣ Обработка папки ----------------
|
| 174 |
+
def process_folder(folder_path, limit=None):
|
| 175 |
+
image_paths, text_paths, width, height = [], [], [], []
|
| 176 |
+
transform = get_image_transform(min_size, max_size, step)
|
| 177 |
+
|
| 178 |
+
for root, _, files in os.walk(folder_path):
|
| 179 |
+
for filename in files:
|
| 180 |
+
if filename.lower().endswith((".jpg",".jpeg",".png",".webp")):
|
| 181 |
+
image_path = os.path.join(root, filename)
|
| 182 |
+
try:
|
| 183 |
+
img = Image.open(image_path)
|
| 184 |
+
except:
|
| 185 |
+
continue
|
| 186 |
+
|
| 187 |
+
w,h = transform(img, dry_run=True)
|
| 188 |
+
text_path = os.path.splitext(image_path)[0]+".txt"
|
| 189 |
+
|
| 190 |
+
if os.path.exists(text_path):
|
| 191 |
+
image_paths.append(image_path)
|
| 192 |
+
text_paths.append(text_path)
|
| 193 |
+
width.append(w)
|
| 194 |
+
height.append(h)
|
| 195 |
+
|
| 196 |
+
print(f"Найдено {len(image_paths)} изображений")
|
| 197 |
+
return image_paths, text_paths, width, height
|
| 198 |
+
|
| 199 |
+
def process_in_chunks(image_paths, text_paths, width, height, chunk_size=5000, batch_size=1):
|
| 200 |
+
total_files = len(image_paths)
|
| 201 |
+
start_time = time.time()
|
| 202 |
+
|
| 203 |
+
for chunk_idx, start in enumerate(range(0,total_files,chunk_size),1):
|
| 204 |
+
end = min(start+chunk_size,total_files)
|
| 205 |
+
|
| 206 |
+
chunk_image_paths = image_paths[start:end]
|
| 207 |
+
chunk_text_paths = text_paths[start:end]
|
| 208 |
+
chunk_widths = width[start:end]
|
| 209 |
+
chunk_heights = height[start:end]
|
| 210 |
+
|
| 211 |
+
chunk_texts = []
|
| 212 |
+
for text_path in chunk_text_paths:
|
| 213 |
+
try:
|
| 214 |
+
with open(text_path,'r',encoding='utf-8') as f:
|
| 215 |
+
chunk_texts.append(f.read().strip())
|
| 216 |
+
except:
|
| 217 |
+
chunk_texts.append("")
|
| 218 |
+
|
| 219 |
+
size_groups = {}
|
| 220 |
+
for i in range(len(chunk_image_paths)):
|
| 221 |
+
key=(chunk_widths[i],chunk_heights[i])
|
| 222 |
+
size_groups.setdefault(key,{"image_paths":[],"texts":[]})
|
| 223 |
+
size_groups[key]["image_paths"].append(chunk_image_paths[i])
|
| 224 |
+
size_groups[key]["texts"].append(chunk_texts[i])
|
| 225 |
+
|
| 226 |
+
for size_key,group_data in size_groups.items():
|
| 227 |
+
group_dataset = Dataset.from_dict(group_data)
|
| 228 |
+
|
| 229 |
+
processed_group = group_dataset.map(
|
| 230 |
+
lambda ex: encode_to_latents(
|
| 231 |
+
[Image.open(p) for p in ex["image_paths"]],
|
| 232 |
+
#[Image.open(p).convert("RGB") for p in ex["image_paths"]], # <--- Добавил .convert("RGB"), чтобы картинка загрузилась в память
|
| 233 |
+
ex["texts"]
|
| 234 |
+
),
|
| 235 |
+
batched=True,
|
| 236 |
+
batch_size=batch_size,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# --- NEW: уникальный путь ---
|
| 240 |
+
group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_{size_key[0]}x{size_key[1]}_proc_{process_index}_"
|
| 241 |
+
# --- END NEW ---
|
| 242 |
+
|
| 243 |
+
processed_group.save_to_disk(group_save_path)
|
| 244 |
+
clear_cuda_memory()
|
| 245 |
+
|
| 246 |
+
# ---------------- 7️⃣ Объединение ----------------
|
| 247 |
+
def combine_chunks(temp_path, final_path):
|
| 248 |
+
chunks = sorted([
|
| 249 |
+
os.path.join(temp_path,d)
|
| 250 |
+
for d in os.listdir(temp_path)
|
| 251 |
+
if "chunk_" in d
|
| 252 |
+
])
|
| 253 |
+
|
| 254 |
+
datasets = [load_from_disk(c) for c in chunks]
|
| 255 |
+
combined = concatenate_datasets(datasets)
|
| 256 |
+
combined.save_to_disk(final_path)
|
| 257 |
+
|
| 258 |
+
print("✅ Сохранено")
|
| 259 |
+
|
| 260 |
+
# ---------------- MAIN ----------------
|
| 261 |
+
temp_path = f"{save_path}_temp"
|
| 262 |
+
os.makedirs(temp_path, exist_ok=True)
|
| 263 |
+
|
| 264 |
+
image_paths, text_paths, width, height = process_folder(folder_path,limit)
|
| 265 |
+
|
| 266 |
+
# сортировка
|
| 267 |
+
sorted_indices = sorted(range(len(width)), key=lambda i:(width[i],height[i]))
|
| 268 |
+
image_paths = [image_paths[i] for i in sorted_indices]
|
| 269 |
+
text_paths = [text_paths[i] for i in sorted_indices]
|
| 270 |
+
width = [width[i] for i in sorted_indices]
|
| 271 |
+
height = [height[i] for i in sorted_indices]
|
| 272 |
+
|
| 273 |
+
# --- shard по GPU ---
|
| 274 |
+
indices = list(range(len(image_paths)))
|
| 275 |
+
indices = indices[process_index::num_processes]
|
| 276 |
+
|
| 277 |
+
image_paths = [image_paths[i] for i in indices]
|
| 278 |
+
text_paths = [text_paths[i] for i in indices]
|
| 279 |
+
width = [width[i] for i in indices]
|
| 280 |
+
height = [height[i] for i in indices]
|
| 281 |
+
|
| 282 |
+
print(f"[GPU {process_index}] обрабатывает {len(image_paths)} файлов")
|
| 283 |
+
|
| 284 |
+
process_in_chunks(image_paths, text_paths, width, height, chunk_size=1000, batch_size=batch_size)
|
| 285 |
+
|
dataset_sample.ipynb
CHANGED
|
@@ -1,170 +1,3 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": 3,
|
| 6 |
-
"id": "9c312df2-cb57-44f6-af54-3af6ab8f962f",
|
| 7 |
-
"metadata": {},
|
| 8 |
-
"outputs": [
|
| 9 |
-
{
|
| 10 |
-
"ename": "ModuleNotFoundError",
|
| 11 |
-
"evalue": "No module named 'numpy'",
|
| 12 |
-
"output_type": "error",
|
| 13 |
-
"traceback": [
|
| 14 |
-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 15 |
-
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
| 16 |
-
"Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#from datasets import load_from_disk\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mPIL\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Image\n",
|
| 17 |
-
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'numpy'"
|
| 18 |
-
]
|
| 19 |
-
}
|
| 20 |
-
],
|
| 21 |
-
"source": [
|
| 22 |
-
"from datasets import load_from_disk\n",
|
| 23 |
-
"import numpy as np\n",
|
| 24 |
-
"import torch\n",
|
| 25 |
-
"from PIL import Image\n",
|
| 26 |
-
"from collections import defaultdict\n",
|
| 27 |
-
"from diffusers import AutoencoderKLQwenImage\n",
|
| 28 |
-
"import gc\n",
|
| 29 |
-
"\n",
|
| 30 |
-
"def analyze_dataset_by_size(dataset_path):\n",
|
| 31 |
-
" \"\"\"\n",
|
| 32 |
-
" Группирует датасет по размерам изображений и выводит базовую информацию.\n",
|
| 33 |
-
" \"\"\"\n",
|
| 34 |
-
" # Настройка устройства и типа данных\n",
|
| 35 |
-
" dtype = torch.float16\n",
|
| 36 |
-
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 37 |
-
" \n",
|
| 38 |
-
" # Загрузка VAE модели\n",
|
| 39 |
-
" print(\"Загрузка VAE модели...\")\n",
|
| 40 |
-
" vae = AutoencoderKLQwenImage.from_pretrained(\"vae\",torch_dtype=dtype).to(device).eval()\n",
|
| 41 |
-
" shift_factor = getattr(vae.config, \"shift_factor\", 0.0)\n",
|
| 42 |
-
" if shift_factor is None:\n",
|
| 43 |
-
" shift_factor = 0.0\n",
|
| 44 |
-
" \n",
|
| 45 |
-
" scaling_factor = getattr(vae.config, \"scaling_factor\", 1.0)\n",
|
| 46 |
-
" if scaling_factor is None:\n",
|
| 47 |
-
" scaling_factor = 1.0\n",
|
| 48 |
-
" \n",
|
| 49 |
-
" mean = getattr(vae.config, \"latents_mean\", None)\n",
|
| 50 |
-
" std = getattr(vae.config, \"latents_std\", None)\n",
|
| 51 |
-
" if mean is not None and std is not None:\n",
|
| 52 |
-
" latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)\n",
|
| 53 |
-
" latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)\n",
|
| 54 |
-
" \n",
|
| 55 |
-
" # Загружаем датасет\n",
|
| 56 |
-
" print(f\"Загрузка датасета из {dataset_path}...\")\n",
|
| 57 |
-
" dataset = load_from_disk(dataset_path)\n",
|
| 58 |
-
"\n",
|
| 59 |
-
" print(f\"Осталось примеров после фильтрации: {len(dataset)}\")\n",
|
| 60 |
-
" \n",
|
| 61 |
-
" # Группируем примеры по размерам\n",
|
| 62 |
-
" print(\"\\nГруппировка примеров по размерам...\")\n",
|
| 63 |
-
" size_to_indices = defaultdict(list)\n",
|
| 64 |
-
" \n",
|
| 65 |
-
" # Собираем примеры с одинаковыми размерами\n",
|
| 66 |
-
" # Собираем примеры с одинаковыми размерами (оптимизированная версия)\n",
|
| 67 |
-
" widths = dataset[\"width\"]\n",
|
| 68 |
-
" heights = dataset[\"height\"]\n",
|
| 69 |
-
" for i, (w, h) in enumerate(zip(widths, heights)):\n",
|
| 70 |
-
" size_to_indices[(w, h)].append(i)\n",
|
| 71 |
-
" \n",
|
| 72 |
-
" # Сортируем размеры по количеству примеров\n",
|
| 73 |
-
" print(\"\\nСортируем...\")\n",
|
| 74 |
-
" size_stats = [(size, len(indices)) for size, indices in size_to_indices.items()]\n",
|
| 75 |
-
" size_stats.sort(key=lambda x: x[1], reverse=True)\n",
|
| 76 |
-
" \n",
|
| 77 |
-
" # Выводим информацию о каждой группе и показываем первый пример\n",
|
| 78 |
-
" for size, count in size_stats:\n",
|
| 79 |
-
" width, height = size\n",
|
| 80 |
-
" first_idx = size_to_indices[size][1]\n",
|
| 81 |
-
" example = dataset[first_idx]\n",
|
| 82 |
-
" \n",
|
| 83 |
-
" print(f\"\\n--- Батч {width}x{height}: {count} примеров ---\")\n",
|
| 84 |
-
" \n",
|
| 85 |
-
" # Декодируем латентное представление для первого примера\n",
|
| 86 |
-
" latent = torch.tensor(example[\"vae\"], dtype=dtype).unsqueeze(0).to(device)\n",
|
| 87 |
-
" \n",
|
| 88 |
-
" # 1. Снова обманываем VAE, превращая картинку в \"видео из 1 кадра\" [B, C, 1, H, W]\n",
|
| 89 |
-
" if latent.ndim == 4:\n",
|
| 90 |
-
" latent = latent.unsqueeze(2)\n",
|
| 91 |
-
" \n",
|
| 92 |
-
" with torch.no_grad():\n",
|
| 93 |
-
" if latents_mean is not None and latents_std is not None:\n",
|
| 94 |
-
" latent = latent * latents_std + latents_mean\n",
|
| 95 |
-
" \n",
|
| 96 |
-
" print(f\"Min of latent_for_vae: {latent.min()}\")\n",
|
| 97 |
-
" print(f\"Max of latent_for_vae: {latent.max()}\")\n",
|
| 98 |
-
" print(f\"Mean of latent_for_vae: {latent.mean()}\")\n",
|
| 99 |
-
" print(f\"Std: {latent.std().item():.4f}\")\n",
|
| 100 |
-
" if torch.isnan(latent).any() or torch.isinf(latent).any():\n",
|
| 101 |
-
" print(\"WARNING: Raw latents contain NaN or Inf values!\")\n",
|
| 102 |
-
" \n",
|
| 103 |
-
" reconstructed_image = vae.decode(latent).sample\n",
|
| 104 |
-
" \n",
|
| 105 |
-
" # 2. Вытаскиваем обычную 3D-картинку [C, H, W] из 5D-видеотензора\n",
|
| 106 |
-
" if reconstructed_image.ndim == 5:\n",
|
| 107 |
-
" # Берем нулевой батч, все каналы, нулевой кадр, всю высоту и ширину\n",
|
| 108 |
-
" img_tensor = reconstructed_image[0, :, 0, :, :] \n",
|
| 109 |
-
" else:\n",
|
| 110 |
-
" img_tensor = reconstructed_image.squeeze(0) # На всякий случай, если VAE вернул 4D\n",
|
| 111 |
-
" \n",
|
| 112 |
-
" img_array = img_tensor.cpu().numpy()\n",
|
| 113 |
-
" img_array = np.transpose(img_array, (1, 2, 0))\n",
|
| 114 |
-
" img_array = (img_array + 1) / 2 # Нормализация к [0, 1]\n",
|
| 115 |
-
" img_array = np.clip(img_array * 255, 0, 255).astype(np.uint8) # Преобразуем в uint8 для PIL\n",
|
| 116 |
-
" \n",
|
| 117 |
-
" # Создаем PIL изображение из массива\n",
|
| 118 |
-
" pil_image = Image.fromarray(img_array)\n",
|
| 119 |
-
" print(f\"Текст: {example['text']}\")\n",
|
| 120 |
-
" print(f\"Ключи: {', '.join(example.keys())}\")\n",
|
| 121 |
-
" print(f\"latent: {latent.shape}\")\n",
|
| 122 |
-
" pil_image.save(\"1.jpg\")\n",
|
| 123 |
-
" \n",
|
| 124 |
-
" # Очистка памяти\n",
|
| 125 |
-
" if torch.cuda.is_available():\n",
|
| 126 |
-
" torch.cuda.empty_cache()\n",
|
| 127 |
-
" gc.collect()\n",
|
| 128 |
-
" \n",
|
| 129 |
-
" return size_to_indices # Возвращаем словарь с индексами по группам\n",
|
| 130 |
-
"\n",
|
| 131 |
-
"# Использование\n",
|
| 132 |
-
"if __name__ == \"__main__\":\n",
|
| 133 |
-
" # Путь к датасету\n",
|
| 134 |
-
" save_path = \"datasets/ds234_640_vae_qwen\"\n",
|
| 135 |
-
" \n",
|
| 136 |
-
" # Анализ датасета\n",
|
| 137 |
-
" size_groups = analyze_dataset_by_size(save_path)"
|
| 138 |
-
]
|
| 139 |
-
},
|
| 140 |
-
{
|
| 141 |
-
"cell_type": "code",
|
| 142 |
-
"execution_count": null,
|
| 143 |
-
"id": "74a5d11d-369f-4f25-9ee0-31d3bccd0254",
|
| 144 |
-
"metadata": {},
|
| 145 |
-
"outputs": [],
|
| 146 |
-
"source": []
|
| 147 |
-
}
|
| 148 |
-
],
|
| 149 |
-
"metadata": {
|
| 150 |
-
"kernelspec": {
|
| 151 |
-
"display_name": "Python 3 (ipykernel)",
|
| 152 |
-
"language": "python",
|
| 153 |
-
"name": "python3"
|
| 154 |
-
},
|
| 155 |
-
"language_info": {
|
| 156 |
-
"codemirror_mode": {
|
| 157 |
-
"name": "ipython",
|
| 158 |
-
"version": 3
|
| 159 |
-
},
|
| 160 |
-
"file_extension": ".py",
|
| 161 |
-
"mimetype": "text/x-python",
|
| 162 |
-
"name": "python",
|
| 163 |
-
"nbconvert_exporter": "python",
|
| 164 |
-
"pygments_lexer": "ipython3",
|
| 165 |
-
"version": "3.12.3"
|
| 166 |
-
}
|
| 167 |
-
},
|
| 168 |
-
"nbformat": 4,
|
| 169 |
-
"nbformat_minor": 5
|
| 170 |
-
}
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:774dc5b6f2f55e8b4e925e5ba984f73b18e2c096b6c1df4bfe0075aa51a56258
|
| 3 |
+
size 8190
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pipeline_sdxs.py
CHANGED
|
@@ -14,12 +14,10 @@ class SdxsPipelineOutput(BaseOutput):
|
|
| 14 |
prompt: Optional[Union[str, List[str]]] = None
|
| 15 |
|
| 16 |
class SdxsPipeline(DiffusionPipeline):
|
| 17 |
-
#
|
| 18 |
-
MAX_TEXT_TOKENS = 512
|
| 19 |
|
| 20 |
def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):
|
| 21 |
super().__init__()
|
| 22 |
-
# Регистрируем модули (с Qwen)
|
| 23 |
self.register_modules(
|
| 24 |
vae=vae,
|
| 25 |
text_encoder=text_encoder,
|
|
@@ -28,62 +26,36 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 28 |
scheduler=scheduler
|
| 29 |
)
|
| 30 |
|
| 31 |
-
self.vae_scale_factor =
|
| 32 |
-
if hasattr(self.vae.config, "block_out_channels"):
|
| 33 |
-
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 34 |
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
std = getattr(self.vae.config, "latents_std", None)
|
| 38 |
-
if mean is not None and std is not None:
|
| 39 |
-
self.vae_latents_mean = torch.tensor(mean).view(1, len(mean), 1, 1, 1)
|
| 40 |
-
# Внимание: Cosmos использует инвертированный std для декодирования (1.0 / std)
|
| 41 |
-
self.vae_latents_std = torch.tensor(std).view(1, len(std), 1, 1, 1)
|
| 42 |
-
else:
|
| 43 |
-
self.vae_latents_mean = None
|
| 44 |
-
self.vae_latents_std = None
|
| 45 |
-
|
| 46 |
-
# Регистрируем параметры Cosmos в шедулере (если они еще не там)
|
| 47 |
-
if self.scheduler is not None:
|
| 48 |
-
self.scheduler.register_to_config(
|
| 49 |
-
sigma_max=getattr(self.scheduler.config, "sigma_max", 80.0),
|
| 50 |
-
sigma_min=getattr(self.scheduler.config, "sigma_min", 0.002),
|
| 51 |
-
sigma_data=getattr(self.scheduler.config, "sigma_data", 1.0),
|
| 52 |
-
final_sigmas_type=getattr(self.scheduler.config, "final_sigmas_type", "sigma_min"),
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
@staticmethod
|
| 56 |
-
def _pad_tensor_to_length(tensor: torch.Tensor, target_len: int, dim: int = 1, pad_value: float = 0) -> torch.Tensor:
|
| 57 |
-
current_len = tensor.shape[dim]
|
| 58 |
-
if current_len >= target_len:
|
| 59 |
-
return tensor
|
| 60 |
-
pad_size = target_len - current_len
|
| 61 |
-
if tensor.dim() == 3:
|
| 62 |
-
padding = (0, 0, 0, pad_size, 0, 0)
|
| 63 |
-
elif tensor.dim() == 2:
|
| 64 |
-
padding = (0, pad_size, 0, 0)
|
| 65 |
-
else:
|
| 66 |
-
raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")
|
| 67 |
-
return torch.nn.functional.pad(tensor, padding, value=pad_value)
|
| 68 |
-
|
| 69 |
-
@torch.no_grad()
|
| 70 |
def refine_prompts(
|
| 71 |
self,
|
| 72 |
prompts: Union[str, List[str]],
|
| 73 |
system_prompt: Optional[str] = None,
|
| 74 |
temperature: float = 0.7
|
| 75 |
) -> List[str]:
|
| 76 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
device = self.device
|
| 78 |
|
|
|
|
| 79 |
if system_prompt is None:
|
| 80 |
system_prompt = (
|
| 81 |
"You are a skilled text-to-image prompt engineer whose sole function is to transform "
|
| 82 |
-
"the user's input into an
|
| 83 |
-
"**The primary subject MUST be the main focus of the revised prompt "
|
| 84 |
-
"and MUST be described in rich detail within the first sentence.** "
|
| 85 |
"Output **only** the final revised prompt, with absolutely no commentary. "
|
| 86 |
-
"Don't use cliches like warm, soft, vibrant, wildflowers.
|
| 87 |
)
|
| 88 |
|
| 89 |
pad_id = getattr(self.text_encoder.config, "pad_token_id", None) or \
|
|
@@ -93,6 +65,7 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 93 |
refined_list = []
|
| 94 |
|
| 95 |
for p in prompts_list:
|
|
|
|
| 96 |
full_text = system_prompt + p
|
| 97 |
messages = [{"role": "user", "content": [{"type": "text", "text": full_text}]}]
|
| 98 |
|
|
@@ -120,7 +93,6 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 120 |
|
| 121 |
@torch.no_grad()
|
| 122 |
def encode_text(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 123 |
-
"""Qwen-specific text encoding (using chat_template and hidden_states[-2])"""
|
| 124 |
device = self.device
|
| 125 |
dtype = self.transformer.dtype
|
| 126 |
if text is None: text = ""
|
|
@@ -128,221 +100,148 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 128 |
|
| 129 |
formatted_prompts = []
|
| 130 |
for t in text:
|
|
|
|
| 131 |
messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
|
| 132 |
formatted_prompts.append(self.tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False))
|
| 133 |
|
| 134 |
-
toks = self.tokenizer(
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
@torch.no_grad()
|
| 143 |
-
def image_upscale(self, image: Union[str, Image.Image, List[Union[str, Image.Image]]], batch_size: int = 1) -> List[Image.Image]:
|
| 144 |
-
images = [image] if isinstance(image, (str, Image.Image)) else image
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
if isinstance(img, str): img = Image.open(img)
|
| 149 |
-
if img.mode == "RGBA":
|
| 150 |
-
img = Image.alpha_composite(Image.new("RGBA", img.size, (255, 255, 255)), img)
|
| 151 |
-
img = img.convert("RGB")
|
| 152 |
-
|
| 153 |
-
w, h = img.size
|
| 154 |
-
pw, ph = (8 - w % 8) % 8, (8 - h % 8) % 8
|
| 155 |
-
if pw or ph:
|
| 156 |
-
padded = Image.new("RGB", (w + pw, h + ph), (255, 255, 255))
|
| 157 |
-
padded.paste(img)
|
| 158 |
-
img = padded
|
| 159 |
-
|
| 160 |
-
t = torch.from_numpy(np.array(img).astype(np.float32) / 127.5 - 1.0).permute(2, 0, 1)
|
| 161 |
-
batch_data.append((t.to(self.device, torch.float16), w, h))
|
| 162 |
-
|
| 163 |
-
unique_shapes = {t.shape for t, _, _ in batch_data}
|
| 164 |
-
step = batch_size if len(unique_shapes) == 1 else 1
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
if decoded.ndim == 5:
|
| 175 |
-
decoded = decoded.squeeze(2)
|
| 176 |
-
|
| 177 |
-
decoded = (decoded.clamp(-1, 1) + 1) / 2
|
| 178 |
-
for j, tensor in enumerate(decoded):
|
| 179 |
-
w, h = chunk[j][1], chunk[j][2]
|
| 180 |
-
arr = tensor.cpu().permute(1, 2, 0).float().numpy()
|
| 181 |
-
arr = arr[:h * 2, :w * 2]
|
| 182 |
-
output_images.append(Image.fromarray((arr * 255).astype("uint8")))
|
| 183 |
-
|
| 184 |
-
return output_images
|
| 185 |
-
|
| 186 |
@torch.no_grad()
|
| 187 |
def __call__(
|
| 188 |
self,
|
| 189 |
prompt: Optional[Union[str, List[str]]] = None,
|
| 190 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 194 |
-
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 195 |
-
latents: Optional[torch.Tensor] = None,
|
| 196 |
-
height: int = 1024,
|
| 197 |
-
width: int = 1024,
|
| 198 |
num_inference_steps: int = 40,
|
| 199 |
guidance_scale: float = 4.0,
|
| 200 |
-
generator: Optional[torch.Generator] = None,
|
| 201 |
seed: Optional[int] = None,
|
| 202 |
output_type: str = "pil",
|
| 203 |
return_dict: bool = True,
|
| 204 |
-
**kwargs,
|
| 205 |
):
|
| 206 |
device = self.device
|
| 207 |
dtype = self.transformer.dtype
|
| 208 |
-
|
| 209 |
-
if
|
| 210 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 211 |
-
|
|
|
|
|
|
|
| 212 |
do_classifier_free_guidance = guidance_scale > 1.0
|
| 213 |
|
| 214 |
-
# 1. Encode
|
| 215 |
-
|
| 216 |
-
if prompt is None: raise ValueError("`prompt` or `prompt_embeds` required.")
|
| 217 |
-
prompt_embeds, prompt_attention_mask = self.encode_text(prompt)
|
| 218 |
-
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
|
| 219 |
-
prompt_attention_mask = prompt_attention_mask.to(device=device, dtype=torch.int64)
|
| 220 |
batch_size = prompt_embeds.shape[0]
|
| 221 |
|
| 222 |
-
# 2. Encode Negative
|
| 223 |
if do_classifier_free_guidance:
|
| 224 |
-
if
|
| 225 |
-
|
| 226 |
-
negative_prompt_embeds, negative_prompt_attention_mask = self.encode_text(neg_text)
|
| 227 |
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1)
|
| 233 |
-
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(batch_size, 1)
|
| 234 |
-
|
| 235 |
-
max_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
|
| 236 |
-
prompt_embeds = self._pad_tensor_to_length(prompt_embeds, max_len, dim=1, pad_value=0)
|
| 237 |
-
negative_prompt_embeds = self._pad_tensor_to_length(negative_prompt_embeds, max_len, dim=1, pad_value=0)
|
| 238 |
-
prompt_attention_mask = self._pad_tensor_to_length(prompt_attention_mask, max_len, dim=1, pad_value=0)
|
| 239 |
-
negative_prompt_attention_mask = self._pad_tensor_to_length(negative_prompt_attention_mask, max_len, dim=1, pad_value=0)
|
| 240 |
-
|
| 241 |
-
text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 242 |
else:
|
| 243 |
text_embeddings = prompt_embeds
|
| 244 |
-
|
| 245 |
-
# 3. Prepare Timesteps (Cosmos specific schedule)
|
| 246 |
-
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
| 247 |
-
sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
|
| 248 |
-
self.scheduler.set_timesteps(sigmas=sigmas, device=device)
|
| 249 |
-
timesteps = self.scheduler.timesteps
|
| 250 |
-
|
| 251 |
-
# Защита от деления на ноль на последнем шаге
|
| 252 |
-
if self.scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
|
| 253 |
-
self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
|
| 254 |
-
if self.scheduler.sigmas[-1] == 0.0:
|
| 255 |
-
self.scheduler.sigmas[-1] = 1e-4
|
| 256 |
|
| 257 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
latent_h = height // self.vae_scale_factor
|
| 259 |
latent_w = width // self.vae_scale_factor
|
| 260 |
in_channels = self.transformer.config.in_channels
|
| 261 |
-
sigma_max = getattr(self.scheduler.config, "sigma_max", 80.0)
|
| 262 |
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
|
|
|
| 269 |
|
| 270 |
-
#
|
| 271 |
-
padding_mask = torch.zeros((1, 1,
|
| 272 |
|
| 273 |
-
#
|
| 274 |
-
for i
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
# Защита от деления на 0 при вычислении current_t
|
| 278 |
-
if current_sigma == 0.0:
|
| 279 |
-
current_sigma = torch.tensor(1e-4, dtype=current_sigma.dtype, device=device)
|
| 280 |
-
|
| 281 |
-
current_t = current_sigma / (current_sigma + 1.0)
|
| 282 |
-
c_in = 1.0 - current_t
|
| 283 |
-
c_skip = 1.0 - current_t
|
| 284 |
-
c_out = -current_t
|
| 285 |
|
|
|
|
| 286 |
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 287 |
-
latent_model_input = (latent_model_input * c_in).to(dtype)
|
| 288 |
|
| 289 |
-
#
|
| 290 |
-
|
| 291 |
-
timestep_tensor = torch.tensor(
|
| 292 |
-
[t_val],
|
| 293 |
-
device=device,
|
| 294 |
-
dtype=dtype
|
| 295 |
-
).view(1, 1, 1, 1, 1).expand(latent_model_input.shape[0], 1, 1, 1, 1)
|
| 296 |
|
| 297 |
-
|
|
|
|
|
|
|
| 298 |
hidden_states=latent_model_input,
|
| 299 |
-
timestep=
|
| 300 |
encoder_hidden_states=text_embeddings,
|
| 301 |
padding_mask=padding_mask,
|
| 302 |
return_dict=False,
|
| 303 |
)[0]
|
| 304 |
-
|
| 305 |
-
batched_latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 306 |
-
noise_pred = (c_skip * batched_latents + c_out * model_out.float()).to(dtype)
|
| 307 |
-
|
| 308 |
if do_classifier_free_guidance:
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
if output_type == "latent":
|
| 317 |
-
if not return_dict: return (latents, prompt)
|
| 318 |
return SdxsPipelineOutput(images=latents)
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
l_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
|
| 324 |
l_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
|
| 325 |
-
|
| 326 |
-
# Оригинальная формула: делим на инвертированный std (что равноценно умножению на std)
|
| 327 |
-
#latents_std_inv = 1.0 / l_std
|
| 328 |
latents = latents * l_std + l_mean
|
| 329 |
|
| 330 |
-
|
|
|
|
|
|
|
| 331 |
|
| 332 |
if image_output.ndim == 5:
|
| 333 |
-
image_output = image_output.squeeze(2)
|
| 334 |
-
|
| 335 |
image_output = (image_output.clamp(-1, 1) + 1) / 2
|
| 336 |
image_np = image_output.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
image_np = np.nan_to_num(image_np, nan=0.0, posinf=1.0, neginf=0.0)
|
| 340 |
-
|
| 341 |
if output_type == "pil":
|
| 342 |
-
images = [
|
| 343 |
else:
|
| 344 |
images = image_np
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
return (images,)
|
| 348 |
-
return SdxsPipelineOutput(images=images)
|
|
|
|
| 14 |
prompt: Optional[Union[str, List[str]]] = None
|
| 15 |
|
| 16 |
class SdxsPipeline(DiffusionPipeline):
|
| 17 |
+
MAX_TEXT_TOKENS = 400 # не Соответствует max_length в обучении
|
|
|
|
| 18 |
|
| 19 |
def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):
|
| 20 |
super().__init__()
|
|
|
|
| 21 |
self.register_modules(
|
| 22 |
vae=vae,
|
| 23 |
text_encoder=text_encoder,
|
|
|
|
| 26 |
scheduler=scheduler
|
| 27 |
)
|
| 28 |
|
| 29 |
+
self.vae_scale_factor = 8
|
|
|
|
|
|
|
| 30 |
|
| 31 |
+
|
| 32 |
+
@torch.no_grad()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def refine_prompts(
|
| 34 |
self,
|
| 35 |
prompts: Union[str, List[str]],
|
| 36 |
system_prompt: Optional[str] = None,
|
| 37 |
temperature: float = 0.7
|
| 38 |
) -> List[str]:
|
| 39 |
+
"""
|
| 40 |
+
Refines a list of prompts using the Text Encoder (LLM).
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
prompts: Single prompt string or list of prompts.
|
| 44 |
+
system_prompt: Custom instruction for the LLM. If None, uses default aesthetic enhancer.
|
| 45 |
+
temperature: Sampling temperature for generation.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
List of refined prompts.
|
| 49 |
+
"""
|
| 50 |
device = self.device
|
| 51 |
|
| 52 |
+
# Default system prompt if none provided
|
| 53 |
if system_prompt is None:
|
| 54 |
system_prompt = (
|
| 55 |
"You are a skilled text-to-image prompt engineer whose sole function is to transform "
|
| 56 |
+
"the user's input into an aesthetic, detailed, and visually descriptive three-sentence output. "
|
|
|
|
|
|
|
| 57 |
"Output **only** the final revised prompt, with absolutely no commentary. "
|
| 58 |
+
"Don't use cliches like warm, soft, vibrant, wildflowers. User input prompt: "
|
| 59 |
)
|
| 60 |
|
| 61 |
pad_id = getattr(self.text_encoder.config, "pad_token_id", None) or \
|
|
|
|
| 65 |
refined_list = []
|
| 66 |
|
| 67 |
for p in prompts_list:
|
| 68 |
+
# Prepend system prompt to user input
|
| 69 |
full_text = system_prompt + p
|
| 70 |
messages = [{"role": "user", "content": [{"type": "text", "text": full_text}]}]
|
| 71 |
|
|
|
|
| 93 |
|
| 94 |
@torch.no_grad()
|
| 95 |
def encode_text(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
| 96 |
device = self.device
|
| 97 |
dtype = self.transformer.dtype
|
| 98 |
if text is None: text = ""
|
|
|
|
| 100 |
|
| 101 |
formatted_prompts = []
|
| 102 |
for t in text:
|
| 103 |
+
# Повторяем логику чат-шаблона из обучения
|
| 104 |
messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
|
| 105 |
formatted_prompts.append(self.tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False))
|
| 106 |
|
| 107 |
+
toks = self.tokenizer(
|
| 108 |
+
formatted_prompts,
|
| 109 |
+
padding="max_length",
|
| 110 |
+
max_length=self.MAX_TEXT_TOKENS,
|
| 111 |
+
truncation=True,
|
| 112 |
+
return_tensors="pt"
|
| 113 |
+
).to(device)
|
| 114 |
|
| 115 |
+
outputs = self.text_encoder(
|
| 116 |
+
input_ids=toks.input_ids,
|
| 117 |
+
attention_mask=toks.attention_mask,
|
| 118 |
+
output_hidden_states=True
|
| 119 |
+
)
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
+
# Берем предпоследний слой (-2) как в обучении
|
| 122 |
+
last_hidden = outputs.hidden_states[-2].to(dtype=dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
+
# Обнуляем паддинги для честности (как в обучении)
|
| 125 |
+
lengths = toks.attention_mask.sum(dim=1)
|
| 126 |
+
for i, length in enumerate(lengths):
|
| 127 |
+
last_hidden[i, length:] = 0
|
| 128 |
+
|
| 129 |
+
return last_hidden, toks.attention_mask.to(dtype=torch.int64)
|
| 130 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
@torch.no_grad()
|
| 132 |
def __call__(
|
| 133 |
self,
|
| 134 |
prompt: Optional[Union[str, List[str]]] = None,
|
| 135 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 136 |
+
height: int = 1152,
|
| 137 |
+
width: int = 768,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
num_inference_steps: int = 40,
|
| 139 |
guidance_scale: float = 4.0,
|
|
|
|
| 140 |
seed: Optional[int] = None,
|
| 141 |
output_type: str = "pil",
|
| 142 |
return_dict: bool = True,
|
|
|
|
| 143 |
):
|
| 144 |
device = self.device
|
| 145 |
dtype = self.transformer.dtype
|
| 146 |
+
|
| 147 |
+
if seed is not None:
|
| 148 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 149 |
+
else:
|
| 150 |
+
generator = None
|
| 151 |
+
|
| 152 |
do_classifier_free_guidance = guidance_scale > 1.0
|
| 153 |
|
| 154 |
+
# 1. Encode Prompts
|
| 155 |
+
prompt_embeds, prompt_mask = self.encode_text(prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
batch_size = prompt_embeds.shape[0]
|
| 157 |
|
|
|
|
| 158 |
if do_classifier_free_guidance:
|
| 159 |
+
neg_text = negative_prompt if negative_prompt is not None else ([""] * batch_size)
|
| 160 |
+
neg_embeds, neg_mask = self.encode_text(neg_text)
|
|
|
|
| 161 |
|
| 162 |
+
# Конкатенация для батч-генерации (uncond + cond)
|
| 163 |
+
text_embeddings = torch.cat([neg_embeds, prompt_embeds], dim=0)
|
| 164 |
+
# В вашем обучении padding_mask в модель передавался как нули,
|
| 165 |
+
# но внутри трансформера обычно используется encoder_attention_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
else:
|
| 167 |
text_embeddings = prompt_embeds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
+
# 2. Prepare Timesteps (Flow Matching: от 1.0 к 0.0)
|
| 170 |
+
# В обучении t=1 был шумом, t=0 — данными.
|
| 171 |
+
timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device, dtype=dtype)
|
| 172 |
+
|
| 173 |
+
# 3. Prepare Latents
|
| 174 |
latent_h = height // self.vae_scale_factor
|
| 175 |
latent_w = width // self.vae_scale_factor
|
| 176 |
in_channels = self.transformer.config.in_channels
|
|
|
|
| 177 |
|
| 178 |
+
# В Flow Matching начальный шум имеет стандартное отклонение 1.0
|
| 179 |
+
latents = torch.randn(
|
| 180 |
+
(batch_size, in_channels, 1, latent_h, latent_w),
|
| 181 |
+
generator=generator,
|
| 182 |
+
device=device,
|
| 183 |
+
dtype=dtype
|
| 184 |
+
)
|
| 185 |
|
| 186 |
+
# Пустая маска как в обучении
|
| 187 |
+
padding_mask = torch.zeros((1, 1, latent_h, latent_w), device=device, dtype=dtype)
|
| 188 |
|
| 189 |
+
# 4. Denoising Loop (Euler Method)
|
| 190 |
+
for i in tqdm(range(num_inference_steps), desc="Sampling"):
|
| 191 |
+
t_curr = timesteps[i]
|
| 192 |
+
t_next = timesteps[i+1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
+
# Подготовка входа (CFG)
|
| 195 |
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
|
|
|
| 196 |
|
| 197 |
+
# Модель обучалась на t.flatten(), передаем как вектор [B]
|
| 198 |
+
t_vec = t_curr.expand(latent_model_input.shape[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
+
# Предсказание "скорости" (v)
|
| 201 |
+
# Т.к. в обучении target = noise - clean, модель предсказывает направление к шуму
|
| 202 |
+
model_output = self.transformer(
|
| 203 |
hidden_states=latent_model_input,
|
| 204 |
+
timestep=t_vec,
|
| 205 |
encoder_hidden_states=text_embeddings,
|
| 206 |
padding_mask=padding_mask,
|
| 207 |
return_dict=False,
|
| 208 |
)[0]
|
| 209 |
+
|
|
|
|
|
|
|
|
|
|
| 210 |
if do_classifier_free_guidance:
|
| 211 |
+
v_uncond, v_cond = model_output.chunk(2)
|
| 212 |
+
v_final = v_uncond + guidance_scale * (v_cond - v_uncond)
|
| 213 |
+
else:
|
| 214 |
+
v_final = model_output
|
| 215 |
+
|
| 216 |
+
# Euler шаг: x_{t-1} = x_t + (t_next - t_curr) * v
|
| 217 |
+
# Поскольку t идет от 1 к 0, (t_next - t_curr) будет отрицательным,
|
| 218 |
+
# что правильно двигает нас от шума к данным.
|
| 219 |
+
latents = latents + (t_next - t_curr) * v_final
|
| 220 |
+
|
| 221 |
+
# 5. Decode
|
| 222 |
if output_type == "latent":
|
|
|
|
| 223 |
return SdxsPipelineOutput(images=latents)
|
| 224 |
+
|
| 225 |
+
# Применяем де-нормализацию VAE как в обучении
|
| 226 |
+
if getattr(self.vae.config, "latents_std", None) is not None:
|
|
|
|
| 227 |
l_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
|
| 228 |
l_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
|
|
|
|
|
|
|
|
|
|
| 229 |
latents = latents * l_std + l_mean
|
| 230 |
|
| 231 |
+
# Декодируем
|
| 232 |
+
with torch.no_grad():
|
| 233 |
+
image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
| 234 |
|
| 235 |
if image_output.ndim == 5:
|
| 236 |
+
image_output = image_output.squeeze(2) # Убираем временную ось (Frames=1)
|
| 237 |
+
|
| 238 |
image_output = (image_output.clamp(-1, 1) + 1) / 2
|
| 239 |
image_np = image_output.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 240 |
+
image_np = np.nan_to_num(image_np, nan=0.0)
|
| 241 |
+
|
|
|
|
|
|
|
| 242 |
if output_type == "pil":
|
| 243 |
+
images = [Image.fromarray((img * 255).round().astype("uint8")) for img in image_np]
|
| 244 |
else:
|
| 245 |
images = image_np
|
| 246 |
+
|
| 247 |
+
return SdxsPipelineOutput(images=images, prompt=prompt)
|
|
|
|
|
|
refined.jpg
ADDED
|
Git LFS Details
|
scheduler/.ipynb_checkpoints/scheduler_config-checkpoint.json
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_class_name": "FlowMatchEulerDiscreteScheduler",
|
| 3 |
-
"_diffusers_version": "0.34.0.dev0",
|
| 4 |
-
"base_image_seq_len": 256,
|
| 5 |
-
"base_shift": 0.5,
|
| 6 |
-
"final_sigmas_type": "sigma_min",
|
| 7 |
-
"invert_sigmas": false,
|
| 8 |
-
"max_image_seq_len": 4096,
|
| 9 |
-
"max_shift": 1.15,
|
| 10 |
-
"num_train_timesteps": 1000,
|
| 11 |
-
"shift": 1.0,
|
| 12 |
-
"shift_terminal": null,
|
| 13 |
-
"sigma_data": 1.0,
|
| 14 |
-
"sigma_max": 80.0,
|
| 15 |
-
"sigma_min": 0.002,
|
| 16 |
-
"stochastic_sampling": false,
|
| 17 |
-
"time_shift_type": "exponential",
|
| 18 |
-
"use_beta_sigmas": false,
|
| 19 |
-
"use_dynamic_shifting": false,
|
| 20 |
-
"use_exponential_sigmas": false,
|
| 21 |
-
"use_karras_sigmas": true
|
| 22 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scheduler/scheduler_config.json
CHANGED
|
@@ -1,22 +1,3 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
"base_image_seq_len": 256,
|
| 5 |
-
"base_shift": 0.5,
|
| 6 |
-
"final_sigmas_type": "sigma_min",
|
| 7 |
-
"invert_sigmas": false,
|
| 8 |
-
"max_image_seq_len": 4096,
|
| 9 |
-
"max_shift": 1.15,
|
| 10 |
-
"num_train_timesteps": 1000,
|
| 11 |
-
"shift": 1.0,
|
| 12 |
-
"shift_terminal": null,
|
| 13 |
-
"sigma_data": 1.0,
|
| 14 |
-
"sigma_max": 80.0,
|
| 15 |
-
"sigma_min": 0.002,
|
| 16 |
-
"stochastic_sampling": false,
|
| 17 |
-
"time_shift_type": "exponential",
|
| 18 |
-
"use_beta_sigmas": false,
|
| 19 |
-
"use_dynamic_shifting": false,
|
| 20 |
-
"use_exponential_sigmas": false,
|
| 21 |
-
"use_karras_sigmas": true
|
| 22 |
-
}
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:65b3e9ccde6e3727aab1c612e7279599f861aec2fb9354880ab9ef8753c492b6
|
| 3 |
+
size 485
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test.ipynb
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:feff1f3730b8dae616e3ffd24b2f74dcd9c6776c46e00ac72018e0de74785d06
|
| 3 |
+
size 18136603
|
train-sdxs2b.py
CHANGED
|
@@ -17,7 +17,7 @@ from torch.utils.data import DataLoader, Sampler
|
|
| 17 |
from torch.optim.lr_scheduler import LambdaLR
|
| 18 |
from collections import defaultdict
|
| 19 |
from accelerate import Accelerator
|
| 20 |
-
from datasets import load_from_disk
|
| 21 |
from tqdm import tqdm
|
| 22 |
from PIL import Image, ImageOps
|
| 23 |
from torch.utils.checkpoint import checkpoint
|
|
@@ -33,7 +33,7 @@ os.environ["NCCL_IB_DISABLE"] = "1" # comment this on H100!
|
|
| 33 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 34 |
|
| 35 |
# --------------------------- Параметры ---------------------------
|
| 36 |
-
ds_path = "/
|
| 37 |
project = "transformer"
|
| 38 |
|
| 39 |
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
|
|
@@ -81,8 +81,8 @@ torch.backends.cuda.matmul.allow_tf32 = True
|
|
| 81 |
torch.backends.cudnn.allow_tf32 = True
|
| 82 |
torch.backends.cuda.enable_flash_sdp(True)
|
| 83 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
| 84 |
-
torch.backends.cuda.enable_math_sdp(
|
| 85 |
-
save_barrier = 1.
|
| 86 |
warmup_percent = 0.0025
|
| 87 |
betta2 = 0.997
|
| 88 |
eps = 1e-6
|
|
@@ -223,7 +223,7 @@ def encode_texts(text, max_length=max_length):
|
|
| 223 |
for i, length in enumerate(lengths):
|
| 224 |
hidden[i, length:] = 0
|
| 225 |
|
| 226 |
-
return hidden, toks.attention_mask.to(dtype=torch.
|
| 227 |
|
| 228 |
@torch.no_grad()
|
| 229 |
def encode_texts_fast(text, max_length=max_length):
|
|
@@ -244,7 +244,7 @@ def encode_texts_fast(text, max_length=max_length):
|
|
| 244 |
for i, length in enumerate(lengths):
|
| 245 |
last_hidden[i, length:] = 0
|
| 246 |
|
| 247 |
-
return last_hidden, toks.attention_mask.to(dtype=torch.
|
| 248 |
|
| 249 |
shift_factor = getattr(vae.config, "shift_factor", 0.0)
|
| 250 |
if shift_factor is None:
|
|
@@ -375,7 +375,7 @@ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
|
|
| 375 |
masks = torch.tensor(
|
| 376 |
np.array([item["attention_mask"] for item in samples_data]),
|
| 377 |
device=device,
|
| 378 |
-
dtype=torch.
|
| 379 |
)
|
| 380 |
else:
|
| 381 |
embeddings, masks = encode_texts(texts,max_length)
|
|
@@ -388,7 +388,30 @@ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
|
|
| 388 |
if limit > 0:
|
| 389 |
dataset = load_from_disk(ds_path).select(range(limit))
|
| 390 |
else:
|
| 391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
print(f"images: {len(dataset)}")
|
| 394 |
|
|
@@ -424,7 +447,7 @@ def collate_fn_simple(batch):
|
|
| 424 |
]
|
| 425 |
|
| 426 |
embeddings, attention_mask = encode_texts(texts,max_length)
|
| 427 |
-
attention_mask = attention_mask.to(dtype=torch.
|
| 428 |
|
| 429 |
return latents, embeddings, attention_mask
|
| 430 |
|
|
@@ -552,7 +575,7 @@ def get_negative_embedding(neg_prompt="", batch_size=1):
|
|
| 552 |
hidden_dim = 2048
|
| 553 |
seq_len = max_length
|
| 554 |
empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
|
| 555 |
-
empty_mask = torch.ones((batch_size, seq_len), dtype=torch.
|
| 556 |
return empty_emb, empty_mask
|
| 557 |
|
| 558 |
uncond_emb, uncond_mask = encode_texts([neg_prompt],max_length)
|
|
|
|
| 17 |
from torch.optim.lr_scheduler import LambdaLR
|
| 18 |
from collections import defaultdict
|
| 19 |
from accelerate import Accelerator
|
| 20 |
+
from datasets import load_from_disk,concatenate_datasets
|
| 21 |
from tqdm import tqdm
|
| 22 |
from PIL import Image, ImageOps
|
| 23 |
from torch.utils.checkpoint import checkpoint
|
|
|
|
| 33 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 34 |
|
| 35 |
# --------------------------- Параметры ---------------------------
|
| 36 |
+
ds_path = "datasets/dsb_640_vae_qwen_temp"
|
| 37 |
project = "transformer"
|
| 38 |
|
| 39 |
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
|
|
|
|
| 81 |
torch.backends.cudnn.allow_tf32 = True
|
| 82 |
torch.backends.cuda.enable_flash_sdp(True)
|
| 83 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
| 84 |
+
torch.backends.cuda.enable_math_sdp(True)
|
| 85 |
+
save_barrier = 1.4
|
| 86 |
warmup_percent = 0.0025
|
| 87 |
betta2 = 0.997
|
| 88 |
eps = 1e-6
|
|
|
|
| 223 |
for i, length in enumerate(lengths):
|
| 224 |
hidden[i, length:] = 0
|
| 225 |
|
| 226 |
+
return hidden, toks.attention_mask.to(dtype=torch.bool)
|
| 227 |
|
| 228 |
@torch.no_grad()
|
| 229 |
def encode_texts_fast(text, max_length=max_length):
|
|
|
|
| 244 |
for i, length in enumerate(lengths):
|
| 245 |
last_hidden[i, length:] = 0
|
| 246 |
|
| 247 |
+
return last_hidden, toks.attention_mask.to(dtype=torch.bool)
|
| 248 |
|
| 249 |
shift_factor = getattr(vae.config, "shift_factor", 0.0)
|
| 250 |
if shift_factor is None:
|
|
|
|
| 375 |
masks = torch.tensor(
|
| 376 |
np.array([item["attention_mask"] for item in samples_data]),
|
| 377 |
device=device,
|
| 378 |
+
dtype=torch.bool
|
| 379 |
)
|
| 380 |
else:
|
| 381 |
embeddings, masks = encode_texts(texts,max_length)
|
|
|
|
| 388 |
if limit > 0:
|
| 389 |
dataset = load_from_disk(ds_path).select(range(limit))
|
| 390 |
else:
|
| 391 |
+
print(">>> Поиск чанков датасета...")
|
| 392 |
+
chunks = []
|
| 393 |
+
for d in os.listdir(ds_path):
|
| 394 |
+
full_p = os.path.join(ds_path, d)
|
| 395 |
+
if os.path.isdir(full_p):
|
| 396 |
+
chunks.append(full_p)
|
| 397 |
+
|
| 398 |
+
if not chunks:
|
| 399 |
+
print("❌ Чанки не найдены!")
|
| 400 |
+
|
| 401 |
+
print(f">>> Найдено чанков: {len(chunks)}. Начинаю загрузку и объединение...")
|
| 402 |
+
|
| 403 |
+
# 2. Ленивая загрузка всех чанков
|
| 404 |
+
# load_from_disk не ест RAM, пока мы не обращаемся к данным
|
| 405 |
+
ds_list = []
|
| 406 |
+
for c in chunks:
|
| 407 |
+
try:
|
| 408 |
+
ds_list.append(load_from_disk(c))
|
| 409 |
+
except Exception as e:
|
| 410 |
+
print(f"⚠️ Ошибка загрузки чанка {c}: {e}")
|
| 411 |
+
|
| 412 |
+
# 3. Конкатенация (создает виртуальный объединенный датасет)
|
| 413 |
+
dataset = concatenate_datasets(ds_list)
|
| 414 |
+
#dataset = load_from_disk(ds_path)
|
| 415 |
|
| 416 |
print(f"images: {len(dataset)}")
|
| 417 |
|
|
|
|
| 447 |
]
|
| 448 |
|
| 449 |
embeddings, attention_mask = encode_texts(texts,max_length)
|
| 450 |
+
attention_mask = attention_mask.to(dtype=torch.bool)
|
| 451 |
|
| 452 |
return latents, embeddings, attention_mask
|
| 453 |
|
|
|
|
| 575 |
hidden_dim = 2048
|
| 576 |
seq_len = max_length
|
| 577 |
empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
|
| 578 |
+
empty_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
|
| 579 |
return empty_emb, empty_mask
|
| 580 |
|
| 581 |
uncond_emb, uncond_mask = encode_texts([neg_prompt],max_length)
|
transformer/diffusion_pytorch_model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ee77d7083d1f968607fbcc531deae347d72c2fb229bbe40356e44a8edae26aec
|
| 3 |
+
size 3912877104
|
wandb/debug-internal.log
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
wandb/debug-internal.log
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
run-20260513_080408-xhrf3max/logs/debug-internal.log
|
wandb/debug.log
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
2026-05-05 07:53:13,457 INFO MainThread:43955 [wandb_setup.py:_flush():81] Current SDK version is 0.26.1
|
| 2 |
-
2026-05-05 07:53:13,457 INFO MainThread:43955 [wandb_setup.py:_flush():81] Configure stats pid to 43955
|
| 3 |
-
2026-05-05 07:53:13,457 INFO MainThread:43955 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 4 |
-
2026-05-05 07:53:13,457 INFO MainThread:43955 [wandb_init.py:setup_run_log_directory():723] Logging user logs to /workspace/2b/wandb/run-20260505_075313-ti70f47q/logs/debug.log
|
| 5 |
-
2026-05-05 07:53:13,458 INFO MainThread:43955 [wandb_init.py:setup_run_log_directory():724] Logging internal logs to /workspace/2b/wandb/run-20260505_075313-ti70f47q/logs/debug-internal.log
|
| 6 |
-
2026-05-05 07:53:13,458 INFO MainThread:43955 [wandb_init.py:init():850] calling init triggers
|
| 7 |
-
2026-05-05 07:53:13,458 INFO MainThread:43955 [wandb_init.py:init():855] wandb.init called with sweep_config: {}
|
| 8 |
-
config: {'batch_size': 24, 'base_learning_rate': 1.3333333333333335e-05, 'num_epochs': 1, 'optimizer_type': 'adafactor', '_wandb': {}}
|
| 9 |
-
2026-05-05 07:53:13,458 INFO MainThread:43955 [wandb_init.py:init():898] starting backend
|
| 10 |
-
2026-05-05 07:53:13,663 INFO MainThread:43955 [wandb_init.py:init():913] sending inform_init request
|
| 11 |
-
2026-05-05 07:53:13,842 INFO MainThread:43955 [wandb_init.py:init():918] backend started and connected
|
| 12 |
-
2026-05-05 07:53:13,844 INFO MainThread:43955 [wandb_init.py:init():988] updated telemetry
|
| 13 |
-
2026-05-05 07:53:13,845 INFO MainThread:43955 [wandb_init.py:init():1011] communicating run to backend with 90.0 second timeout
|
| 14 |
-
2026-05-05 07:53:14,174 INFO MainThread:43955 [wandb_init.py:init():1056] starting run threads in backend
|
| 15 |
-
2026-05-05 07:53:14,261 INFO MainThread:43955 [wandb_run.py:_console_start():2554] atexit reg
|
| 16 |
-
2026-05-05 07:53:14,262 INFO MainThread:43955 [wandb_run.py:_redirect():2403] redirect: wrap_raw
|
| 17 |
-
2026-05-05 07:53:14,262 INFO MainThread:43955 [wandb_run.py:_redirect():2472] Wrapping output streams.
|
| 18 |
-
2026-05-05 07:53:14,262 INFO MainThread:43955 [wandb_run.py:_redirect():2495] Redirects installed.
|
| 19 |
-
2026-05-05 07:53:14,267 INFO MainThread:43955 [wandb_init.py:init():1094] run started, returning control to user process
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
wandb/debug.log
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
run-20260513_080408-xhrf3max/logs/debug.log
|