babkasotona commited on
Commit
58bb2b7
·
verified ·
1 Parent(s): 9074517

Upload folder using huggingface_hub

Browse files
Files changed (39) hide show
  1. .gitattributes +4 -0
  2. .gitignore +21 -0
  3. Untitled.ipynb +139 -0
  4. dataset.py +300 -0
  5. dataset_sample.ipynb +170 -0
  6. model_index.json +24 -0
  7. pipeline_sdxs.py +348 -0
  8. pipeline_sdxs_t5.py +291 -0
  9. scheduler/.ipynb_checkpoints/scheduler_config-checkpoint.json +22 -0
  10. scheduler/scheduler_config.json +22 -0
  11. t.py +116 -0
  12. test.ipynb +3 -0
  13. text_encoder/.ipynb_checkpoints/config-checkpoint.json +101 -0
  14. text_encoder/config.json +101 -0
  15. text_encoder/model.safetensors +3 -0
  16. tokenizer/chat_template.jinja +154 -0
  17. tokenizer/tokenizer.json +3 -0
  18. tokenizer/tokenizer_config.json +32 -0
  19. train-Copy1.py +924 -0
  20. transformer/config.json +37 -0
  21. transformer/diffusion_pytorch_model.safetensors +3 -0
  22. vae/.ipynb_checkpoints/config-checkpoint.json +56 -0
  23. vae/config.json +56 -0
  24. vae/diffusion_pytorch_model.safetensors +3 -0
  25. wandb/debug-cli.root.log +0 -0
  26. wandb/debug-internal.log +0 -0
  27. wandb/debug.log +19 -0
  28. wandb/offline-run-20260428_132658-o9052r27/files/requirements.txt +117 -0
  29. wandb/offline-run-20260428_132658-o9052r27/logs/debug-core.log +14 -0
  30. wandb/offline-run-20260428_132658-o9052r27/logs/debug-internal.log +15 -0
  31. wandb/offline-run-20260428_132658-o9052r27/logs/debug.log +21 -0
  32. wandb/offline-run-20260428_132658-o9052r27/run-o9052r27.wandb +0 -0
  33. wandb/run-20260428_171645-wt40fdyx/files/output.log +385 -0
  34. wandb/run-20260428_171645-wt40fdyx/files/requirements.txt +117 -0
  35. wandb/run-20260428_171645-wt40fdyx/files/wandb-metadata.json +46 -0
  36. wandb/run-20260428_171645-wt40fdyx/logs/debug-core.log +7 -0
  37. wandb/run-20260428_171645-wt40fdyx/logs/debug-internal.log +0 -0
  38. wandb/run-20260428_171645-wt40fdyx/logs/debug.log +19 -0
  39. wandb/run-20260428_171645-wt40fdyx/run-wt40fdyx.wandb +3 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ media/refined.jpg filter=lfs diff=lfs merge=lfs -text
38
+ test.ipynb filter=lfs diff=lfs merge=lfs -text
39
+ wandb/run-20260428_171645-wt40fdyx/run-wt40fdyx.wandb filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Jupyter Notebook
2
+ __pycache__/
3
+ *.pyc
4
+ .ipynb_checkpoints/
5
+ *.ipynb_checkpoints/*
6
+ .ipynb_checkpoints/*
7
+ src/samples
8
+ # cache
9
+ cache
10
+ datasets
11
+ test
12
+ wandb
13
+ nohup.out
14
+ samples/
15
+ transformer/
16
+ *.jpg
17
+ *.png
18
+ datasets/
19
+ samples/
20
+ *.jpg
21
+ train.py
Untitled.ipynb ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "7e8f9dc5-d07a-4538-bc03-8953412a72fa",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Keyword arguments {'safety_checker': <__main__.DummyCosmosSafetyChecker object at 0x7f7e8c3fb620>} are not expected by SdxsPipeline and will be ignored.\n"
14
+ ]
15
+ },
16
+ {
17
+ "data": {
18
+ "application/vnd.jupyter.widget-view+json": {
19
+ "model_id": "99e2522e93064308b5dd34923a133c39",
20
+ "version_major": 2,
21
+ "version_minor": 0
22
+ },
23
+ "text/plain": [
24
+ "Loading pipeline components...: 0%| | 0/5 [00:00<?, ?it/s]"
25
+ ]
26
+ },
27
+ "metadata": {},
28
+ "output_type": "display_data"
29
+ },
30
+ {
31
+ "data": {
32
+ "application/vnd.jupyter.widget-view+json": {
33
+ "model_id": "def23cc245b2470a9012f70d5e4c78ed",
34
+ "version_major": 2,
35
+ "version_minor": 0
36
+ },
37
+ "text/plain": [
38
+ "Loading weights: 0%| | 0/195 [00:00<?, ?it/s]"
39
+ ]
40
+ },
41
+ "metadata": {},
42
+ "output_type": "display_data"
43
+ },
44
+ {
45
+ "name": "stderr",
46
+ "output_type": "stream",
47
+ "text": [
48
+ "The config attributes {'final_sigmas_type': 'sigma_min', 'sigma_data': 1.0, 'sigma_max': 80.0, 'sigma_min': 0.002} were passed to FlowMatchEulerDiscreteScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.\n",
49
+ "Sampling: 100%|██████████| 40/40 [00:10<00:00, 3.65it/s]\n"
50
+ ]
51
+ },
52
+ {
53
+ "name": "stdout",
54
+ "output_type": "stream",
55
+ "text": [
56
+ "Готово! Изображение сохранено как output.png\n"
57
+ ]
58
+ }
59
+ ],
60
+ "source": [
61
+ "import torch\n",
62
+ "from diffusers import Cosmos2TextToImagePipeline\n",
63
+ "\n",
64
+ "class DummyCosmosSafetyChecker:\n",
65
+ " def to(self, *args, **kwargs):\n",
66
+ " return self\n",
67
+ " \n",
68
+ " def eval(self):\n",
69
+ " return self\n",
70
+ "\n",
71
+ " # Обход проверки текста\n",
72
+ " def check_text_safety(self, prompt, *args, **kwargs):\n",
73
+ " return True\n",
74
+ "\n",
75
+ " # Обход проверки \"видео\" (картинки из 1 кадра)\n",
76
+ " def check_video_safety(self, vid, *args, **kwargs):\n",
77
+ " # Просто возвращаем тензор обратно без изменений\n",
78
+ " return vid\n",
79
+ "\n",
80
+ " # На всякий случай оставляем оригинальный __call__\n",
81
+ " def __call__(self, images, **kwargs):\n",
82
+ " return images, [False] * len(images)\n",
83
+ "\n",
84
+ "model_id = \"/workspace/sdxs-2b\"\n",
85
+ "\n",
86
+ "pipe = Cosmos2TextToImagePipeline.from_pretrained(\n",
87
+ " model_id,\n",
88
+ " safety_checker=DummyCosmosSafetyChecker(), \n",
89
+ " torch_dtype=torch.bfloat16 \n",
90
+ ")\n",
91
+ "pipe.to(\"cuda\")\n",
92
+ "\n",
93
+ "prompt = \"In a serene garden, two young girls stand side by side, their youthful energy palpable. The girl on the left, adorned with a blue dress and a matching blue flower in her hair, gazes directly at the viewer, her eyes sparkling with curiosity.\"#\"There is a young male character standing against a vibrant, colorful graffiti wall. he is wearing a hat, a jacket adorned with gold accents, and black shorts.\"\n",
94
+ "negative_prompt = \"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.\"\n",
95
+ "\n",
96
+ "# 3. Генерируем изображение\n",
97
+ "output = pipe(\n",
98
+ " height = 1024,\n",
99
+ " width=1024,\n",
100
+ " prompt=prompt, \n",
101
+ " negative_prompt=negative_prompt, \n",
102
+ " generator=torch.Generator(device=\"cuda\").manual_seed(1)\n",
103
+ ").images[0]\n",
104
+ "\n",
105
+ "output.save(\"output.png\")\n",
106
+ "print(\"Готово! Изображение сохранено как output.png\")"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "id": "8a173167-6c28-4bbd-8879-1375e0fd37f0",
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": []
116
+ }
117
+ ],
118
+ "metadata": {
119
+ "kernelspec": {
120
+ "display_name": "Python3 (ipykernel)",
121
+ "language": "python",
122
+ "name": "python3"
123
+ },
124
+ "language_info": {
125
+ "codemirror_mode": {
126
+ "name": "ipython",
127
+ "version": 3
128
+ },
129
+ "file_extension": ".py",
130
+ "mimetype": "text/x-python",
131
+ "name": "python",
132
+ "nbconvert_exporter": "python",
133
+ "pygments_lexer": "ipython3",
134
+ "version": "3.12.13"
135
+ }
136
+ },
137
+ "nbformat": 4,
138
+ "nbformat_minor": 5
139
+ }
dataset.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pip install flash-attn --no-build-isolation
2
+ import torch
3
+ import os
4
+ import gc
5
+ import numpy as np
6
+ import random
7
+ import json
8
+ import shutil
9
+ import time
10
+
11
+ from datasets import Dataset, load_from_disk, concatenate_datasets
12
+ from diffusers import AutoencoderKLQwenImage
13
+ from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
14
+ from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
15
+ from typing import Dict, List, Tuple, Optional, Any
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+ from datetime import timedelta
19
+ from accelerate import Accelerator
20
+
21
+ accelerator = Accelerator()
22
+ device = accelerator.device
23
+ is_main_process = accelerator.is_main_process
24
+ process_index = accelerator.process_index
25
+ num_processes = accelerator.num_processes
26
+
27
+ # ---------------- 1️⃣ Настройки ----------------
28
+ dtype = torch.float16
29
+ batch_size = 5
30
+ min_size = 320
31
+ max_size = 640
32
+ step = 32
33
+ empty_share = 0.0
34
+ limit = 0
35
+
36
+ folder_path = "/workspace/dataset/d23"
37
+ save_path = "/workspace/ds234_640_vae_qwen"
38
+ os.makedirs(save_path, exist_ok=True)
39
+
40
+ def clear_cuda_memory():
41
+ if torch.cuda.is_available():
42
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
43
+ print(f"[GPU {process_index}] used_gb: {used_gb:.2f} GB")
44
+ torch.cuda.empty_cache()
45
+ gc.collect()
46
+
47
+ # ---------------- 2️⃣ Загрузка моделей ----------------
48
+ def load_models():
49
+ print(f"[GPU {process_index}] Загрузка моделей...")
50
+ vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
51
+ return vae
52
+
53
+ vae = load_models()
54
+
55
+ shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
56
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0) or 1.0
57
+
58
+ mean = getattr(vae.config, "latents_mean", None)
59
+ std = getattr(vae.config, "latents_std", None)
60
+ if mean is not None and std is not None:
61
+ latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1, 1)
62
+ latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1, 1)
63
+
64
+ # ---------------- 3️⃣ Трансформации ----------------
65
+ def get_image_transform(min_size=256, max_size=512, step=64):
66
+ def transform(img, dry_run=False):
67
+ original_width, original_height = img.size
68
+
69
+ if original_width >= original_height:
70
+ new_width = max_size
71
+ new_height = int(max_size * original_height / original_width)
72
+ else:
73
+ new_height = max_size
74
+ new_width = int(max_size * original_width / original_height)
75
+
76
+ if new_height < min_size or new_width < min_size:
77
+ if original_width <= original_height:
78
+ new_width = min_size
79
+ new_height = int(min_size * original_height / original_width)
80
+ else:
81
+ new_height = min_size
82
+ new_width = int(min_size * original_width / original_height)
83
+
84
+ crop_width = min(max_size, (new_width // step) * step)
85
+ crop_height = min(max_size, (new_height // step) * step)
86
+
87
+ crop_width = max(min_size, crop_width)
88
+ crop_height = max(min_size, crop_height)
89
+
90
+ if dry_run:
91
+ return crop_width, crop_height
92
+
93
+ img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
94
+
95
+ top = (new_height - crop_height) // 3
96
+ left = 0
97
+
98
+ img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))
99
+
100
+ final_width, final_height = img_cropped.size
101
+
102
+ img_tensor = ToTensor()(img_cropped)
103
+ img_tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(img_tensor)
104
+ return img_tensor, img_cropped, final_width, final_height
105
+
106
+ return transform
107
+
108
+ # ---------------- 4️⃣ Функции обработки ----------------
109
+ def clean_label(label):
110
+ label = label.replace("Image 1","").replace("Image 2","").replace("Image 3","").replace("Image 4","")
111
+ label = label.replace("The image depicts ","").replace("The image presents ","")
112
+ label = label.replace("The image features ","").replace("The image portrays ","").replace("The image is ","").strip()
113
+ if label.startswith("."):
114
+ label = label[1:].lstrip()
115
+ return label
116
+
117
+ def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
118
+ labels_for_model = []
119
+ labels_for_logging = []
120
+
121
+ for label in original_labels:
122
+ if random.random() < prob_to_make_empty:
123
+ labels_for_model.append("")
124
+ labels_for_logging.append(f"zero: {label}")
125
+ else:
126
+ labels_for_model.append(label)
127
+ labels_for_logging.append(label)
128
+
129
+ return labels_for_model, labels_for_logging
130
+
131
+ def encode_to_latents(images, texts):
132
+ transform = get_image_transform(min_size, max_size, step)
133
+
134
+ transformed_tensors = []
135
+ widths, heights = [], []
136
+
137
+ for img in images:
138
+ try:
139
+ t_img, _, w, h = transform(img)
140
+ transformed_tensors.append(t_img)
141
+ widths.append(w)
142
+ heights.append(h)
143
+ except Exception as e:
144
+ print(f"Ошибка трансформации: {e}")
145
+
146
+ if not transformed_tensors:
147
+ return None
148
+
149
+ batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
150
+
151
+ if batch_tensor.ndim==4:
152
+ batch_tensor = batch_tensor.unsqueeze(2)
153
+
154
+ with torch.no_grad():
155
+ posteriors = vae.encode(batch_tensor).latent_dist.mode()
156
+ if mean is not None and std is not None:
157
+ posteriors = (posteriors - latents_mean) / latents_std
158
+ posteriors = (posteriors - shift_factor) / scaling_factor
159
+
160
+ #latents_np = posteriors.cpu().numpy()
161
+ latents_np = posteriors.squeeze(2).cpu().numpy()
162
+
163
+ text_labels = [clean_label(text) for text in texts]
164
+ _, text_labels = process_labels_for_guidance(text_labels, empty_share)
165
+
166
+ return {
167
+ "vae": latents_np,
168
+ "text": text_labels,
169
+ "width": widths,
170
+ "height": heights
171
+ }
172
+
173
+ # ---------------- 5️⃣ Обработка папки ----------------
174
+ def process_folder(folder_path, limit=None):
175
+ image_paths, text_paths, width, height = [], [], [], []
176
+ transform = get_image_transform(min_size, max_size, step)
177
+
178
+ for root, _, files in os.walk(folder_path):
179
+ for filename in files:
180
+ if filename.lower().endswith((".jpg",".jpeg",".png")):
181
+ image_path = os.path.join(root, filename)
182
+ try:
183
+ img = Image.open(image_path)
184
+ except:
185
+ continue
186
+
187
+ w,h = transform(img, dry_run=True)
188
+ text_path = os.path.splitext(image_path)[0]+".txt"
189
+
190
+ if os.path.exists(text_path):
191
+ image_paths.append(image_path)
192
+ text_paths.append(text_path)
193
+ width.append(w)
194
+ height.append(h)
195
+
196
+ print(f"Найдено {len(image_paths)} изображений")
197
+ return image_paths, text_paths, width, height
198
+
199
+ def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=1):
200
+ total_files = len(image_paths)
201
+ start_time = time.time()
202
+
203
+ for chunk_idx, start in enumerate(range(0,total_files,chunk_size),1):
204
+ end = min(start+chunk_size,total_files)
205
+
206
+ chunk_image_paths = image_paths[start:end]
207
+ chunk_text_paths = text_paths[start:end]
208
+ chunk_widths = width[start:end]
209
+ chunk_heights = height[start:end]
210
+
211
+ chunk_texts = []
212
+ for text_path in chunk_text_paths:
213
+ try:
214
+ with open(text_path,'r',encoding='utf-8') as f:
215
+ chunk_texts.append(f.read().strip())
216
+ except:
217
+ chunk_texts.append("")
218
+
219
+ size_groups = {}
220
+ for i in range(len(chunk_image_paths)):
221
+ key=(chunk_widths[i],chunk_heights[i])
222
+ size_groups.setdefault(key,{"image_paths":[],"texts":[]})
223
+ size_groups[key]["image_paths"].append(chunk_image_paths[i])
224
+ size_groups[key]["texts"].append(chunk_texts[i])
225
+
226
+ for size_key,group_data in size_groups.items():
227
+ group_dataset = Dataset.from_dict(group_data)
228
+
229
+ processed_group = group_dataset.map(
230
+ lambda ex: encode_to_latents(
231
+ [Image.open(p) for p in ex["image_paths"]],
232
+ #[Image.open(p).convert("RGB") for p in ex["image_paths"]], # <--- Добавил .convert("RGB"), чтобы картинка загрузилась в память
233
+ ex["texts"]
234
+ ),
235
+ batched=True,
236
+ batch_size=batch_size,
237
+ )
238
+
239
+ # --- NEW: уникальный путь ---
240
+ group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_{size_key[0]}x{size_key[1]}_proc_{process_index}_"
241
+ # --- END NEW ---
242
+
243
+ processed_group.save_to_disk(group_save_path)
244
+ clear_cuda_memory()
245
+
246
+ # ---------------- 7️⃣ Объединение ----------------
247
+ def combine_chunks(temp_path, final_path):
248
+ chunks = sorted([
249
+ os.path.join(temp_path,d)
250
+ for d in os.listdir(temp_path)
251
+ if "chunk_" in d
252
+ ])
253
+
254
+ datasets = [load_from_disk(c) for c in chunks]
255
+ combined = concatenate_datasets(datasets)
256
+ combined.save_to_disk(final_path)
257
+
258
+ print("✅ Сохранено")
259
+
260
+ # ---------------- MAIN ----------------
261
+ temp_path = f"{save_path}_temp"
262
+ os.makedirs(temp_path, exist_ok=True)
263
+
264
+ image_paths, text_paths, width, height = process_folder(folder_path,limit)
265
+
266
+ # сортировка
267
+ sorted_indices = sorted(range(len(width)), key=lambda i:(width[i],height[i]))
268
+ image_paths = [image_paths[i] for i in sorted_indices]
269
+ text_paths = [text_paths[i] for i in sorted_indices]
270
+ width = [width[i] for i in sorted_indices]
271
+ height = [height[i] for i in sorted_indices]
272
+
273
+ # --- shard по GPU ---
274
+ indices = list(range(len(image_paths)))
275
+ indices = indices[process_index::num_processes]
276
+
277
+ image_paths = [image_paths[i] for i in indices]
278
+ text_paths = [text_paths[i] for i in indices]
279
+ width = [width[i] for i in indices]
280
+ height = [height[i] for i in indices]
281
+
282
+ print(f"[GPU {process_index}] обрабатывает {len(image_paths)} файлов")
283
+
284
+ process_in_chunks(image_paths, text_paths, width, height, chunk_size=5000, batch_size=batch_size)
285
+
286
+ accelerator.wait_for_everyone()
287
+
288
+ # --- NEW: только главный процесс ---
289
+ if is_main_process:
290
+ #try:
291
+ #shutil.rmtree(folder_path)
292
+ #except:
293
+ # pass
294
+
295
+ combine_chunks(temp_path, save_path)
296
+
297
+ try:
298
+ shutil.rmtree(temp_path)
299
+ except:
300
+ pass
dataset_sample.ipynb ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "id": "9c312df2-cb57-44f6-af54-3af6ab8f962f",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "ename": "ModuleNotFoundError",
11
+ "evalue": "No module named 'numpy'",
12
+ "output_type": "error",
13
+ "traceback": [
14
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
15
+ "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
16
+ "Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#from datasets import load_from_disk\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mPIL\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Image\n",
17
+ "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'numpy'"
18
+ ]
19
+ }
20
+ ],
21
+ "source": [
22
+ "from datasets import load_from_disk\n",
23
+ "import numpy as np\n",
24
+ "import torch\n",
25
+ "from PIL import Image\n",
26
+ "from collections import defaultdict\n",
27
+ "from diffusers import AutoencoderKLQwenImage\n",
28
+ "import gc\n",
29
+ "\n",
30
+ "def analyze_dataset_by_size(dataset_path):\n",
31
+ " \"\"\"\n",
32
+ " Группирует датасет по размерам изображений и выводит базовую информацию.\n",
33
+ " \"\"\"\n",
34
+ " # Настройка устройства и типа данных\n",
35
+ " dtype = torch.float16\n",
36
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
37
+ " \n",
38
+ " # Загрузка VAE модели\n",
39
+ " print(\"Загрузка VAE модели...\")\n",
40
+ " vae = AutoencoderKLQwenImage.from_pretrained(\"vae\",torch_dtype=dtype).to(device).eval()\n",
41
+ " shift_factor = getattr(vae.config, \"shift_factor\", 0.0)\n",
42
+ " if shift_factor is None:\n",
43
+ " shift_factor = 0.0\n",
44
+ " \n",
45
+ " scaling_factor = getattr(vae.config, \"scaling_factor\", 1.0)\n",
46
+ " if scaling_factor is None:\n",
47
+ " scaling_factor = 1.0\n",
48
+ " \n",
49
+ " mean = getattr(vae.config, \"latents_mean\", None)\n",
50
+ " std = getattr(vae.config, \"latents_std\", None)\n",
51
+ " if mean is not None and std is not None:\n",
52
+ " latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)\n",
53
+ " latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)\n",
54
+ " \n",
55
+ " # Загружаем датасет\n",
56
+ " print(f\"Загрузка датасета из {dataset_path}...\")\n",
57
+ " dataset = load_from_disk(dataset_path)\n",
58
+ "\n",
59
+ " print(f\"Осталось примеров после фильтрации: {len(dataset)}\")\n",
60
+ " \n",
61
+ " # Группируем примеры по размерам\n",
62
+ " print(\"\\nГруппировка примеров по размерам...\")\n",
63
+ " size_to_indices = defaultdict(list)\n",
64
+ " \n",
65
+ " # Собираем примеры с одинаковыми размерами\n",
66
+ " # Собираем примеры с одинаковыми размерами (оптимизированная версия)\n",
67
+ " widths = dataset[\"width\"]\n",
68
+ " heights = dataset[\"height\"]\n",
69
+ " for i, (w, h) in enumerate(zip(widths, heights)):\n",
70
+ " size_to_indices[(w, h)].append(i)\n",
71
+ " \n",
72
+ " # Сортируем размеры по количеству примеров\n",
73
+ " print(\"\\nСортируем...\")\n",
74
+ " size_stats = [(size, len(indices)) for size, indices in size_to_indices.items()]\n",
75
+ " size_stats.sort(key=lambda x: x[1], reverse=True)\n",
76
+ " \n",
77
+ " # Выводим информацию о каждой группе и показываем первый пример\n",
78
+ " for size, count in size_stats:\n",
79
+ " width, height = size\n",
80
+ " first_idx = size_to_indices[size][1]\n",
81
+ " example = dataset[first_idx]\n",
82
+ " \n",
83
+ " print(f\"\\n--- Батч {width}x{height}: {count} примеров ---\")\n",
84
+ " \n",
85
+ " # Декодируем латентное представление для первого примера\n",
86
+ " latent = torch.tensor(example[\"vae\"], dtype=dtype).unsqueeze(0).to(device)\n",
87
+ " \n",
88
+ " # 1. Снова обманываем VAE, превращая картинку в \"видео из 1 кадра\" [B, C, 1, H, W]\n",
89
+ " if latent.ndim == 4:\n",
90
+ " latent = latent.unsqueeze(2)\n",
91
+ " \n",
92
+ " with torch.no_grad():\n",
93
+ " if latents_mean is not None and latents_std is not None:\n",
94
+ " latent = latent * latents_std + latents_mean\n",
95
+ " \n",
96
+ " print(f\"Min of latent_for_vae: {latent.min()}\")\n",
97
+ " print(f\"Max of latent_for_vae: {latent.max()}\")\n",
98
+ " print(f\"Mean of latent_for_vae: {latent.mean()}\")\n",
99
+ " print(f\"Std: {latent.std().item():.4f}\")\n",
100
+ " if torch.isnan(latent).any() or torch.isinf(latent).any():\n",
101
+ " print(\"WARNING: Raw latents contain NaN or Inf values!\")\n",
102
+ " \n",
103
+ " reconstructed_image = vae.decode(latent).sample\n",
104
+ " \n",
105
+ " # 2. Вытаскиваем обычную 3D-картинку [C, H, W] из 5D-видеотензора\n",
106
+ " if reconstructed_image.ndim == 5:\n",
107
+ " # Берем нулевой батч, все каналы, нулевой кадр, всю высоту и ширину\n",
108
+ " img_tensor = reconstructed_image[0, :, 0, :, :] \n",
109
+ " else:\n",
110
+ " img_tensor = reconstructed_image.squeeze(0) # На всякий случай, если VAE вернул 4D\n",
111
+ " \n",
112
+ " img_array = img_tensor.cpu().numpy()\n",
113
+ " img_array = np.transpose(img_array, (1, 2, 0))\n",
114
+ " img_array = (img_array + 1) / 2 # Нормализация к [0, 1]\n",
115
+ " img_array = np.clip(img_array * 255, 0, 255).astype(np.uint8) # Преобразуем в uint8 для PIL\n",
116
+ " \n",
117
+ " # Создаем PIL изображение из массива\n",
118
+ " pil_image = Image.fromarray(img_array)\n",
119
+ " print(f\"Текст: {example['text']}\")\n",
120
+ " print(f\"Ключи: {', '.join(example.keys())}\")\n",
121
+ " print(f\"latent: {latent.shape}\")\n",
122
+ " pil_image.save(\"1.jpg\")\n",
123
+ " \n",
124
+ " # Очистка памяти\n",
125
+ " if torch.cuda.is_available():\n",
126
+ " torch.cuda.empty_cache()\n",
127
+ " gc.collect()\n",
128
+ " \n",
129
+ " return size_to_indices # Возвращаем словарь с индексами по группам\n",
130
+ "\n",
131
+ "# Использование\n",
132
+ "if __name__ == \"__main__\":\n",
133
+ " # Путь к датасету\n",
134
+ " save_path = \"datasets/ds234_640_vae_qwen\"\n",
135
+ " \n",
136
+ " # Анализ датасета\n",
137
+ " size_groups = analyze_dataset_by_size(save_path)"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": null,
143
+ "id": "74a5d11d-369f-4f25-9ee0-31d3bccd0254",
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": []
147
+ }
148
+ ],
149
+ "metadata": {
150
+ "kernelspec": {
151
+ "display_name": "Python 3 (ipykernel)",
152
+ "language": "python",
153
+ "name": "python3"
154
+ },
155
+ "language_info": {
156
+ "codemirror_mode": {
157
+ "name": "ipython",
158
+ "version": 3
159
+ },
160
+ "file_extension": ".py",
161
+ "mimetype": "text/x-python",
162
+ "name": "python",
163
+ "nbconvert_exporter": "python",
164
+ "pygments_lexer": "ipython3",
165
+ "version": "3.12.3"
166
+ }
167
+ },
168
+ "nbformat": 4,
169
+ "nbformat_minor": 5
170
+ }
model_index.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": ["pipeline_sdxs", "SdxsPipeline"],
3
+ "_diffusers_version": "0.36.0",
4
+ "scheduler": [
5
+ "diffusers",
6
+ "FlowMatchEulerDiscreteScheduler"
7
+ ],
8
+ "text_encoder": [
9
+ "transformers",
10
+ "Qwen3_5ForConditionalGeneration"
11
+ ],
12
+ "tokenizer": [
13
+ "transformers",
14
+ "Qwen3_5Tokenizer"
15
+ ],
16
+ "transformer": [
17
+ "diffusers",
18
+ "CosmosTransformer3DModel"
19
+ ],
20
+ "vae": [
21
+ "diffusers",
22
+ "AutoencoderKLQwenImage"
23
+ ]
24
+ }
pipeline_sdxs.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from typing import List, Union, Optional, Tuple
5
+ from dataclasses import dataclass
6
+
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.utils import BaseOutput
9
+ from tqdm import tqdm
10
+
11
+ @dataclass
12
+ class SdxsPipelineOutput(BaseOutput):
13
+ images: Union[List[Image.Image], np.ndarray]
14
+ prompt: Optional[Union[str, List[str]]] = None
15
+
16
+ class SdxsPipeline(DiffusionPipeline):
17
+ # Cosmos требует 512 токенов
18
+ MAX_TEXT_TOKENS = 512
19
+
20
+ def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):
21
+ super().__init__()
22
+ # Регистрируем модули (с Qwen)
23
+ self.register_modules(
24
+ vae=vae,
25
+ text_encoder=text_encoder,
26
+ tokenizer=tokenizer,
27
+ transformer=transformer,
28
+ scheduler=scheduler
29
+ )
30
+
31
+ self.vae_scale_factor = getattr(self.vae.config, "spatial_compression_ratio", 8)
32
+ if hasattr(self.vae.config, "block_out_channels"):
33
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
34
+
35
+ # Загружаем mean и std для VAE (Cosmos-style)
36
+ mean = getattr(self.vae.config, "latents_mean", None)
37
+ std = getattr(self.vae.config, "latents_std", None)
38
+ if mean is not None and std is not None:
39
+ self.vae_latents_mean = torch.tensor(mean).view(1, len(mean), 1, 1, 1)
40
+ # Внимание: Cosmos использует инвертированный std для декодирования (1.0 / std)
41
+ self.vae_latents_std = torch.tensor(std).view(1, len(std), 1, 1, 1)
42
+ else:
43
+ self.vae_latents_mean = None
44
+ self.vae_latents_std = None
45
+
46
+ # Регистрируем параметры Cosmos в шедулере (если они еще не там)
47
+ if self.scheduler is not None:
48
+ self.scheduler.register_to_config(
49
+ sigma_max=getattr(self.scheduler.config, "sigma_max", 80.0),
50
+ sigma_min=getattr(self.scheduler.config, "sigma_min", 0.002),
51
+ sigma_data=getattr(self.scheduler.config, "sigma_data", 1.0),
52
+ final_sigmas_type=getattr(self.scheduler.config, "final_sigmas_type", "sigma_min"),
53
+ )
54
+
55
+ @staticmethod
56
+ def _pad_tensor_to_length(tensor: torch.Tensor, target_len: int, dim: int = 1, pad_value: float = 0) -> torch.Tensor:
57
+ current_len = tensor.shape[dim]
58
+ if current_len >= target_len:
59
+ return tensor
60
+ pad_size = target_len - current_len
61
+ if tensor.dim() == 3:
62
+ padding = (0, 0, 0, pad_size, 0, 0)
63
+ elif tensor.dim() == 2:
64
+ padding = (0, pad_size, 0, 0)
65
+ else:
66
+ raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")
67
+ return torch.nn.functional.pad(tensor, padding, value=pad_value)
68
+
69
+ @torch.no_grad()
70
+ def refine_prompts(
71
+ self,
72
+ prompts: Union[str, List[str]],
73
+ system_prompt: Optional[str] = None,
74
+ temperature: float = 0.7
75
+ ) -> List[str]:
76
+ """Refines a list of prompts using the Text Encoder (LLM)."""
77
+ device = self.device
78
+
79
+ if system_prompt is None:
80
+ system_prompt = (
81
+ "You are a skilled text-to-image prompt engineer whose sole function is to transform "
82
+ "the user's input into an aesthetically optimized, detailed, and visually descriptive two-sentence output. "
83
+ "**The primary subject MUST be the main focus of the revised prompt "
84
+ "and MUST be described in rich detail within the first sentence.** "
85
+ "Output **only** the final revised prompt, with absolutely no commentary. "
86
+ "Don't use cliches like warm, soft, vibrant, wildflowers. Be creative. User input prompt: "
87
+ )
88
+
89
+ pad_id = getattr(self.text_encoder.config, "pad_token_id", None) or \
90
+ getattr(self.text_encoder.config, "eos_token_id", None)
91
+
92
+ prompts_list = [prompts] if isinstance(prompts, str) else prompts
93
+ refined_list = []
94
+
95
+ for p in prompts_list:
96
+ full_text = system_prompt + p
97
+ messages = [{"role": "user", "content": [{"type": "text", "text": full_text}]}]
98
+
99
+ inputs = self.tokenizer.apply_chat_template(
100
+ messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
101
+ ).to(device)
102
+
103
+ generated_ids = self.text_encoder.generate(
104
+ **inputs,
105
+ max_new_tokens=self.MAX_TEXT_TOKENS,
106
+ do_sample=True,
107
+ temperature=temperature,
108
+ pad_token_id=pad_id
109
+ )
110
+
111
+ generated_ids_trimmed = [
112
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
113
+ ]
114
+ output_text = self.tokenizer.batch_decode(
115
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
116
+ )
117
+ refined_list.append(output_text[0])
118
+
119
+ return refined_list
120
+
121
+ @torch.no_grad()
122
+ def encode_text(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
123
+ """Qwen-specific text encoding (using chat_template and hidden_states[-2])"""
124
+ device = self.device
125
+ dtype = self.transformer.dtype
126
+ if text is None: text = ""
127
+ if isinstance(text, str): text = [text]
128
+
129
+ formatted_prompts = []
130
+ for t in text:
131
+ messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
132
+ formatted_prompts.append(self.tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False))
133
+
134
+ toks = self.tokenizer(formatted_prompts, padding="max_length", max_length=self.MAX_TEXT_TOKENS, truncation=True, return_tensors="pt").to(device)
135
+ outputs = self.text_encoder(input_ids=toks.input_ids, attention_mask=toks.attention_mask, output_hidden_states=True)
136
+
137
+ # Берем предпоследний слой эмбеддингов, как того требуют современные пайплайны
138
+ last_hidden = outputs.hidden_states[-2]
139
+
140
+ return last_hidden.to(dtype=dtype), toks.attention_mask.to(dtype=torch.int64)
141
+
142
+ @torch.no_grad()
143
+ def image_upscale(self, image: Union[str, Image.Image, List[Union[str, Image.Image]]], batch_size: int = 1) -> List[Image.Image]:
144
+ images = [image] if isinstance(image, (str, Image.Image)) else image
145
+
146
+ batch_data = []
147
+ for img in images:
148
+ if isinstance(img, str): img = Image.open(img)
149
+ if img.mode == "RGBA":
150
+ img = Image.alpha_composite(Image.new("RGBA", img.size, (255, 255, 255)), img)
151
+ img = img.convert("RGB")
152
+
153
+ w, h = img.size
154
+ pw, ph = (8 - w % 8) % 8, (8 - h % 8) % 8
155
+ if pw or ph:
156
+ padded = Image.new("RGB", (w + pw, h + ph), (255, 255, 255))
157
+ padded.paste(img)
158
+ img = padded
159
+
160
+ t = torch.from_numpy(np.array(img).astype(np.float32) / 127.5 - 1.0).permute(2, 0, 1)
161
+ batch_data.append((t.to(self.device, torch.float16), w, h))
162
+
163
+ unique_shapes = {t.shape for t, _, _ in batch_data}
164
+ step = batch_size if len(unique_shapes) == 1 else 1
165
+
166
+ output_images = []
167
+ for i in range(0, len(batch_data), step):
168
+ chunk = batch_data[i : i + step]
169
+ tensors = torch.stack([c[0] for c in chunk]).unsqueeze(2)
170
+
171
+ latents = self.vae.encode(tensors).latent_dist.mean
172
+ decoded = self.vae.decode(latents.to(self.vae.dtype))[0]
173
+
174
+ if decoded.ndim == 5:
175
+ decoded = decoded.squeeze(2)
176
+
177
+ decoded = (decoded.clamp(-1, 1) + 1) / 2
178
+ for j, tensor in enumerate(decoded):
179
+ w, h = chunk[j][1], chunk[j][2]
180
+ arr = tensor.cpu().permute(1, 2, 0).float().numpy()
181
+ arr = arr[:h * 2, :w * 2]
182
+ output_images.append(Image.fromarray((arr * 255).astype("uint8")))
183
+
184
+ return output_images
185
+
186
+ @torch.no_grad()
187
+ def __call__(
188
+ self,
189
+ prompt: Optional[Union[str, List[str]]] = None,
190
+ negative_prompt: Optional[Union[str, List[str]]] = None,
191
+ prompt_embeds: Optional[torch.Tensor] = None,
192
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
193
+ prompt_attention_mask: Optional[torch.Tensor] = None,
194
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
195
+ latents: Optional[torch.Tensor] = None,
196
+ height: int = 1024,
197
+ width: int = 1024,
198
+ num_inference_steps: int = 40,
199
+ guidance_scale: float = 4.0,
200
+ generator: Optional[torch.Generator] = None,
201
+ seed: Optional[int] = None,
202
+ output_type: str = "pil",
203
+ return_dict: bool = True,
204
+ **kwargs,
205
+ ):
206
+ device = self.device
207
+ dtype = self.transformer.dtype
208
+
209
+ if generator is None and seed is not None:
210
+ generator = torch.Generator(device=device).manual_seed(seed)
211
+
212
+ do_classifier_free_guidance = guidance_scale > 1.0
213
+
214
+ # 1. Encode Positive
215
+ if prompt_embeds is None:
216
+ if prompt is None: raise ValueError("`prompt` or `prompt_embeds` required.")
217
+ prompt_embeds, prompt_attention_mask = self.encode_text(prompt)
218
+ prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
219
+ prompt_attention_mask = prompt_attention_mask.to(device=device, dtype=torch.int64)
220
+ batch_size = prompt_embeds.shape[0]
221
+
222
+ # 2. Encode Negative
223
+ if do_classifier_free_guidance:
224
+ if negative_prompt_embeds is None:
225
+ neg_text = negative_prompt if negative_prompt is not None else ("" if isinstance(prompt, str) else [""] * len(prompt))
226
+ negative_prompt_embeds, negative_prompt_attention_mask = self.encode_text(neg_text)
227
+
228
+ negative_prompt_embeds = negative_prompt_embeds.to(device=device, dtype=dtype)
229
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device, dtype=torch.int64)
230
+
231
+ if negative_prompt_embeds.shape[0] != batch_size:
232
+ negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1)
233
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(batch_size, 1)
234
+
235
+ max_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
236
+ prompt_embeds = self._pad_tensor_to_length(prompt_embeds, max_len, dim=1, pad_value=0)
237
+ negative_prompt_embeds = self._pad_tensor_to_length(negative_prompt_embeds, max_len, dim=1, pad_value=0)
238
+ prompt_attention_mask = self._pad_tensor_to_length(prompt_attention_mask, max_len, dim=1, pad_value=0)
239
+ negative_prompt_attention_mask = self._pad_tensor_to_length(negative_prompt_attention_mask, max_len, dim=1, pad_value=0)
240
+
241
+ text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
242
+ else:
243
+ text_embeddings = prompt_embeds
244
+
245
+ # 3. Prepare Timesteps (Cosmos specific schedule)
246
+ sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
247
+ sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
248
+ self.scheduler.set_timesteps(sigmas=sigmas, device=device)
249
+ timesteps = self.scheduler.timesteps
250
+
251
+ # Защита от деления на ноль на последнем шаге
252
+ if self.scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
253
+ self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
254
+ if self.scheduler.sigmas[-1] == 0.0:
255
+ self.scheduler.sigmas[-1] = 1e-4
256
+
257
+ # 4. Prepare Latents (Noise)
258
+ latent_h = height // self.vae_scale_factor
259
+ latent_w = width // self.vae_scale_factor
260
+ in_channels = self.transformer.config.in_channels
261
+ sigma_max = getattr(self.scheduler.config, "sigma_max", 80.0)
262
+
263
+ if latents is None:
264
+ # Создаем 5D тензор [Batch, Channels, Frames, Height, Width]
265
+ latents = torch.randn((batch_size, in_channels, 1, latent_h, latent_w), generator=generator, device=device, dtype=dtype)
266
+ latents = latents * sigma_max
267
+ else:
268
+ latents = latents.to(device=device, dtype=dtype) * sigma_max
269
+
270
+ # Cosmos Padding Mask
271
+ padding_mask = torch.zeros((1, 1, height, width), device=device, dtype=dtype)
272
+
273
+ # 5. Denoising Loop (Continuous Flow Math)
274
+ for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
275
+ current_sigma = self.scheduler.sigmas[i]
276
+
277
+ # Защита от деления на 0 при вычислении current_t
278
+ if current_sigma == 0.0:
279
+ current_sigma = torch.tensor(1e-4, dtype=current_sigma.dtype, device=device)
280
+
281
+ current_t = current_sigma / (current_sigma + 1.0)
282
+ c_in = 1.0 - current_t
283
+ c_skip = 1.0 - current_t
284
+ c_out = -current_t
285
+
286
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
287
+ latent_model_input = (latent_model_input * c_in).to(dtype)
288
+
289
+ # Трансформер ждет timestep в виде 1D тензора [B]
290
+ t_val = float(current_t.item()) if torch.is_tensor(current_t) else float(current_t)
291
+ timestep_tensor = torch.tensor(
292
+ [t_val],
293
+ device=device,
294
+ dtype=dtype
295
+ ).view(1, 1, 1, 1, 1).expand(latent_model_input.shape[0], 1, 1, 1, 1)
296
+
297
+ model_out = self.transformer(
298
+ hidden_states=latent_model_input,
299
+ timestep=timestep_tensor,
300
+ encoder_hidden_states=text_embeddings,
301
+ padding_mask=padding_mask,
302
+ return_dict=False,
303
+ )[0]
304
+
305
+ batched_latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
306
+ noise_pred = (c_skip * batched_latents + c_out * model_out.float()).to(dtype)
307
+
308
+ if do_classifier_free_guidance:
309
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
310
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
311
+
312
+ noise_pred = (latents - noise_pred) / current_sigma
313
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
314
+
315
+ # 6. Decode
316
+ if output_type == "latent":
317
+ if not return_dict: return (latents, prompt)
318
+ return SdxsPipelineOutput(images=latents)
319
+
320
+ if getattr(self.vae.config, "latents_std", None) is not None and getattr(self.vae.config, "latents_mean", None) is not None:
321
+ sigma_data = getattr(self.scheduler.config, "sigma_data", 1.0)
322
+
323
+ l_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
324
+ l_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
325
+
326
+ # Оригинальная формула: делим на инвертированный std (что равноценно умножению на std)
327
+ #latents_std_inv = 1.0 / l_std
328
+ latents = latents * l_std + l_mean
329
+
330
+ image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
331
+
332
+ if image_output.ndim == 5:
333
+ image_output = image_output.squeeze(2)
334
+
335
+ image_output = (image_output.clamp(-1, 1) + 1) / 2
336
+ image_np = image_output.cpu().permute(0, 2, 3, 1).float().numpy()
337
+
338
+ # На всякий случай вычищаем NaNs
339
+ image_np = np.nan_to_num(image_np, nan=0.0, posinf=1.0, neginf=0.0)
340
+
341
+ if output_type == "pil":
342
+ images = [(Image.fromarray((img * 255).round().astype("uint8"))) for img in image_np]
343
+ else:
344
+ images = image_np
345
+
346
+ if not return_dict:
347
+ return (images,)
348
+ return SdxsPipelineOutput(images=images)
pipeline_sdxs_t5.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from typing import List, Union, Optional, Tuple
5
+ from dataclasses import dataclass
6
+
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.utils import BaseOutput
9
+ from tqdm import tqdm
10
+
11
+ @dataclass
12
+ class SdxsPipelineOutput(BaseOutput):
13
+ images: Union[List[Image.Image], np.ndarray]
14
+ prompt: Optional[Union[str, List[str]]] = None
15
+
16
+ class SdxsPipeline(DiffusionPipeline):
17
+ # Cosmos требует 512 токенов
18
+ MAX_TEXT_TOKENS = 512
19
+
20
+ def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):
21
+ super().__init__()
22
+ self.register_modules(
23
+ vae=vae,
24
+ text_encoder=text_encoder,
25
+ tokenizer=tokenizer,
26
+ transformer=transformer,
27
+ scheduler=scheduler
28
+ )
29
+
30
+ self.vae_scale_factor = getattr(self.vae.config, "spatial_compression_ratio", 8)
31
+ if hasattr(self.vae.config, "block_out_channels"):
32
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
33
+
34
+ # Регистрируем параметры Cosmos в шедулере
35
+ if self.scheduler is not None:
36
+ self.scheduler.register_to_config(
37
+ sigma_max=getattr(self.scheduler.config, "sigma_max", 80.0),
38
+ sigma_min=getattr(self.scheduler.config, "sigma_min", 0.002),
39
+ sigma_data=getattr(self.scheduler.config, "sigma_data", 1.0),
40
+ final_sigmas_type=getattr(self.scheduler.config, "final_sigmas_type", "sigma_min"),
41
+ )
42
+
43
+ @staticmethod
44
+ def _pad_tensor_to_length(tensor: torch.Tensor, target_len: int, dim: int = 1, pad_value: float = 0) -> torch.Tensor:
45
+ current_len = tensor.shape[dim]
46
+ if current_len >= target_len:
47
+ return tensor
48
+ pad_size = target_len - current_len
49
+ if tensor.dim() == 3:
50
+ padding = (0, 0, 0, pad_size, 0, 0)
51
+ elif tensor.dim() == 2:
52
+ padding = (0, pad_size, 0, 0)
53
+ else:
54
+ raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")
55
+ return torch.nn.functional.pad(tensor, padding, value=pad_value)
56
+
57
+ @torch.no_grad()
58
+ def refine_prompts(
59
+ self,
60
+ prompts: Union[str, List[str]],
61
+ system_prompt: Optional[str] = None,
62
+ temperature: float = 0.7
63
+ ) -> List[str]:
64
+ return [prompts] if isinstance(prompts, str) else prompts
65
+
66
+ @torch.no_grad()
67
+ def encode_text(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
68
+ device = self.device
69
+ dtype = self.transformer.dtype
70
+ if text is None: text = ""
71
+ if isinstance(text, str): text = [text]
72
+
73
+ text_inputs = self.tokenizer(
74
+ text,
75
+ padding="max_length",
76
+ max_length=self.MAX_TEXT_TOKENS,
77
+ truncation=True,
78
+ return_tensors="pt"
79
+ )
80
+
81
+ text_input_ids = text_inputs.input_ids.to(device)
82
+ attention_mask = text_inputs.attention_mask.to(device)
83
+
84
+ outputs = self.text_encoder(input_ids=text_input_ids, attention_mask=attention_mask)
85
+ prompt_embeds = outputs.last_hidden_state
86
+
87
+ lengths = attention_mask.sum(dim=1)
88
+ for i, length in enumerate(lengths):
89
+ prompt_embeds[i, length:] = 0
90
+
91
+ return prompt_embeds.to(dtype=dtype), attention_mask.to(dtype=torch.int64)
92
+
93
+ @torch.no_grad()
94
+ def image_upscale(self, image: Union[str, Image.Image, List[Union[str, Image.Image]]], batch_size: int = 1) -> List[Image.Image]:
95
+ images = [image] if isinstance(image, (str, Image.Image)) else image
96
+
97
+ batch_data = []
98
+ for img in images:
99
+ if isinstance(img, str): img = Image.open(img)
100
+ if img.mode == "RGBA":
101
+ img = Image.alpha_composite(Image.new("RGBA", img.size, (255, 255, 255)), img)
102
+ img = img.convert("RGB")
103
+
104
+ w, h = img.size
105
+ pw, ph = (8 - w % 8) % 8, (8 - h % 8) % 8
106
+ if pw or ph:
107
+ padded = Image.new("RGB", (w + pw, h + ph), (255, 255, 255))
108
+ padded.paste(img)
109
+ img = padded
110
+
111
+ t = torch.from_numpy(np.array(img).astype(np.float32) / 127.5 - 1.0).permute(2, 0, 1)
112
+ batch_data.append((t.to(self.device, torch.float16), w, h))
113
+
114
+ unique_shapes = {t.shape for t, _, _ in batch_data}
115
+ step = batch_size if len(unique_shapes) == 1 else 1
116
+
117
+ output_images = []
118
+ for i in range(0, len(batch_data), step):
119
+ chunk = batch_data[i : i + step]
120
+ tensors = torch.stack([c[0] for c in chunk]).unsqueeze(2)
121
+
122
+ latents = self.vae.encode(tensors).latent_dist.mean
123
+ decoded = self.vae.decode(latents.to(self.vae.dtype))[0]
124
+
125
+ if decoded.ndim == 5:
126
+ decoded = decoded.squeeze(2)
127
+
128
+ decoded = (decoded.clamp(-1, 1) + 1) / 2
129
+ for j, tensor in enumerate(decoded):
130
+ w, h = chunk[j][1], chunk[j][2]
131
+ arr = tensor.cpu().permute(1, 2, 0).float().numpy()
132
+ arr = arr[:h * 2, :w * 2]
133
+ output_images.append(Image.fromarray((arr * 255).astype("uint8")))
134
+
135
+ return output_images
136
+
137
+ @torch.no_grad()
138
+ def __call__(
139
+ self,
140
+ prompt: Optional[Union[str, List[str]]] = None,
141
+ negative_prompt: Optional[Union[str, List[str]]] = None,
142
+ prompt_embeds: Optional[torch.Tensor] = None,
143
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
144
+ prompt_attention_mask: Optional[torch.Tensor] = None,
145
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
146
+ latents: Optional[torch.Tensor] = None,
147
+ height: int = 1024,
148
+ width: int = 1024,
149
+ num_inference_steps: int = 40,
150
+ guidance_scale: float = 7.0,
151
+ generator: Optional[torch.Generator] = None,
152
+ seed: Optional[int] = None,
153
+ output_type: str = "pil",
154
+ return_dict: bool = True,
155
+ **kwargs,
156
+ ):
157
+ device = self.device
158
+ dtype = self.transformer.dtype
159
+
160
+ if generator is None and seed is not None:
161
+ generator = torch.Generator(device=device).manual_seed(seed)
162
+
163
+ do_classifier_free_guidance = guidance_scale > 1.0
164
+
165
+ # 1. Encode Positive
166
+ if prompt_embeds is None:
167
+ if prompt is None: raise ValueError("`prompt` or `prompt_embeds` required.")
168
+ prompt_embeds, _ = self.encode_text(prompt)
169
+ prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
170
+ batch_size = prompt_embeds.shape[0]
171
+
172
+ # 2. Encode Negative
173
+ if do_classifier_free_guidance:
174
+ if negative_prompt_embeds is None:
175
+ neg_text = negative_prompt if negative_prompt is not None else ("" if isinstance(prompt, str) else [""] * len(prompt))
176
+ negative_prompt_embeds, _ = self.encode_text(neg_text)
177
+
178
+ negative_prompt_embeds = negative_prompt_embeds.to(device=device, dtype=dtype)
179
+
180
+ if negative_prompt_embeds.shape[0] != batch_size:
181
+ negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1)
182
+
183
+ max_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
184
+ prompt_embeds = self._pad_tensor_to_length(prompt_embeds, max_len, dim=1, pad_value=0)
185
+ negative_prompt_embeds = self._pad_tensor_to_length(negative_prompt_embeds, max_len, dim=1, pad_value=0)
186
+
187
+ # 3. Prepare Timesteps
188
+ sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
189
+ sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
190
+ self.scheduler.set_timesteps(sigmas=sigmas, device=device)
191
+ timesteps = self.scheduler.timesteps
192
+
193
+ # Защита от деления на ноль на последнем шаге
194
+ if self.scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
195
+ self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
196
+ if self.scheduler.sigmas[-1] == 0.0:
197
+ self.scheduler.sigmas[-1] = 1e-4
198
+
199
+ # 4. Prepare Latents (Noise)
200
+ latent_h = height // self.vae_scale_factor
201
+ latent_w = width // self.vae_scale_factor
202
+ in_channels = self.transformer.config.in_channels
203
+ sigma_max = getattr(self.scheduler.config, "sigma_max", 80.0)
204
+
205
+ if latents is None:
206
+ latents = torch.randn((batch_size, in_channels, 1, latent_h, latent_w), generator=generator, device=device, dtype=dtype)
207
+ latents = latents * sigma_max
208
+ else:
209
+ latents = latents.to(device=device, dtype=dtype) * sigma_max
210
+
211
+ # Cosmos Padding Mask
212
+ padding_mask = latents.new_zeros(1, 1, height, width, dtype=dtype)
213
+
214
+ # 5. Denoising Loop
215
+ for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
216
+ current_sigma = self.scheduler.sigmas[i]
217
+
218
+ # Защита от деления на 0 при вычислении current_t
219
+ if current_sigma == 0.0:
220
+ current_sigma = torch.tensor(1e-4, dtype=current_sigma.dtype, device=device)
221
+
222
+ current_t = current_sigma / (current_sigma + 1.0)
223
+ c_in = 1.0 - current_t
224
+ c_skip = 1.0 - current_t
225
+ c_out = -current_t
226
+
227
+ latent_model_input = (latents * c_in).to(dtype)
228
+ timestep = current_t.expand(latents.shape[0]).to(dtype)
229
+
230
+ # Проход 1
231
+ noise_pred = self.transformer(
232
+ hidden_states=latent_model_input,
233
+ timestep=timestep,
234
+ encoder_hidden_states=prompt_embeds,
235
+ padding_mask=padding_mask,
236
+ return_dict=False,
237
+ )[0]
238
+
239
+ noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(dtype)
240
+
241
+ # Проход 2
242
+ if do_classifier_free_guidance:
243
+ noise_pred_uncond = self.transformer(
244
+ hidden_states=latent_model_input,
245
+ timestep=timestep,
246
+ encoder_hidden_states=negative_prompt_embeds,
247
+ padding_mask=padding_mask,
248
+ return_dict=False,
249
+ )[0]
250
+
251
+ noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(dtype)
252
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
253
+
254
+ noise_pred = (latents - noise_pred) / current_sigma
255
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
256
+
257
+ # 6. Decode
258
+ if output_type == "latent":
259
+ if not return_dict: return (latents, prompt)
260
+ return SdxsPipelineOutput(images=latents)
261
+
262
+ # Точная математика NVIDIA для декодирования (без двойных инверсий)
263
+ if getattr(self.vae.config, "latents_std", None) is not None and getattr(self.vae.config, "latents_mean", None) is not None:
264
+ sigma_data = getattr(self.scheduler.config, "sigma_data", 1.0)
265
+
266
+ l_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
267
+ l_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
268
+
269
+ # Оригинальная формула: делим на инвертированный std (что равноценно умножению на std)
270
+ latents_std_inv = 1.0 / l_std
271
+ latents = latents / latents_std_inv / sigma_data + l_mean
272
+
273
+ image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
274
+
275
+ if image_output.ndim == 5:
276
+ image_output = image_output.squeeze(2)
277
+
278
+ image_output = (image_output.clamp(-1, 1) + 1) / 2
279
+ image_np = image_output.cpu().permute(0, 2, 3, 1).float().numpy()
280
+
281
+ # На всякий случай вычищаем NaNs, если они проскользнули, чтобы скрипт не падал с кастом
282
+ image_np = np.nan_to_num(image_np, nan=0.0, posinf=1.0, neginf=0.0)
283
+
284
+ if output_type == "pil":
285
+ images = [(Image.fromarray((img * 255).round().astype("uint8"))) for img in image_np]
286
+ else:
287
+ images = image_np
288
+
289
+ if not return_dict:
290
+ return (images,)
291
+ return SdxsPipelineOutput(images=images)
scheduler/.ipynb_checkpoints/scheduler_config-checkpoint.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.34.0.dev0",
4
+ "base_image_seq_len": 256,
5
+ "base_shift": 0.5,
6
+ "final_sigmas_type": "sigma_min",
7
+ "invert_sigmas": false,
8
+ "max_image_seq_len": 4096,
9
+ "max_shift": 1.15,
10
+ "num_train_timesteps": 1000,
11
+ "shift": 1.0,
12
+ "shift_terminal": null,
13
+ "sigma_data": 1.0,
14
+ "sigma_max": 80.0,
15
+ "sigma_min": 0.002,
16
+ "stochastic_sampling": false,
17
+ "time_shift_type": "exponential",
18
+ "use_beta_sigmas": false,
19
+ "use_dynamic_shifting": false,
20
+ "use_exponential_sigmas": false,
21
+ "use_karras_sigmas": true
22
+ }
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.34.0.dev0",
4
+ "base_image_seq_len": 256,
5
+ "base_shift": 0.5,
6
+ "final_sigmas_type": "sigma_min",
7
+ "invert_sigmas": false,
8
+ "max_image_seq_len": 4096,
9
+ "max_shift": 1.15,
10
+ "num_train_timesteps": 1000,
11
+ "shift": 1.0,
12
+ "shift_terminal": null,
13
+ "sigma_data": 1.0,
14
+ "sigma_max": 80.0,
15
+ "sigma_min": 0.002,
16
+ "stochastic_sampling": false,
17
+ "time_shift_type": "exponential",
18
+ "use_beta_sigmas": false,
19
+ "use_dynamic_shifting": false,
20
+ "use_exponential_sigmas": false,
21
+ "use_karras_sigmas": true
22
+ }
t.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_from_disk
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ from collections import defaultdict
6
+ from diffusers import AutoencoderKLQwenImage
7
+ import gc
8
+
9
+ def analyze_dataset_by_size(dataset_path):
10
+ """
11
+ Группирует датасет по размерам изображений и выводит базовую информацию.
12
+ """
13
+ # Настройка устройства и типа данных
14
+ dtype = torch.float32
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ # Загрузка VAE модели
18
+ print("Загрузка VAE модели...")
19
+ vae = AutoencoderKLQwenImage.from_pretrained("vae",torch_dtype=dtype).to(device).eval()
20
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
21
+ if shift_factor is None:
22
+ shift_factor = 0.0
23
+
24
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
25
+ if scaling_factor is None:
26
+ scaling_factor = 1.0
27
+
28
+ mean = getattr(vae.config, "latents_mean", None)
29
+ std = getattr(vae.config, "latents_std", None)
30
+ if mean is not None and std is not None:
31
+ latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
32
+ latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)
33
+
34
+ # Загружаем датасет
35
+ print(f"Загрузка датасета из {dataset_path}...")
36
+ dataset = load_from_disk(dataset_path)
37
+
38
+ print(f"Осталось примеров после фильтрации: {len(dataset)}")
39
+
40
+ # Группируем примеры по размерам
41
+ print("\nГруппировка примеров по размерам...")
42
+ size_to_indices = defaultdict(list)
43
+
44
+ # Собираем примеры с одинаковыми размерами
45
+ # Собираем примеры с одинаковыми размерами (оптимизированная версия)
46
+ widths = dataset["width"]
47
+ heights = dataset["height"]
48
+ for i, (w, h) in enumerate(zip(widths, heights)):
49
+ size_to_indices[(w, h)].append(i)
50
+
51
+ # Сортируем размеры по количеству примеров
52
+ print("\nСортируем...")
53
+ size_stats = [(size, len(indices)) for size, indices in size_to_indices.items()]
54
+ size_stats.sort(key=lambda x: x[1], reverse=True)
55
+
56
+ # Выводим информацию о каждой группе и показываем первый пример
57
+ for size, count in size_stats:
58
+ width, height = size
59
+ first_idx = size_to_indices[size][1]
60
+ example = dataset[first_idx]
61
+
62
+ print(f"\n--- Батч {width}x{height}: {count} примеров ---")
63
+
64
+ # Декодируем латентное представление для первого примера
65
+ latent = torch.tensor(example["vae"], dtype=dtype).unsqueeze(0).to(device)
66
+
67
+ # 1. Снова обманываем VAE, превращая картинку в "видео из 1 кадра" [B, C, 1, H, W]
68
+ if latent.ndim == 4:
69
+ latent = latent.unsqueeze(2)
70
+
71
+ with torch.no_grad():
72
+ if latents_mean is not None and latents_std is not None:
73
+ latent = latent * latents_std + latents_mean
74
+
75
+ print(f"Min of latent_for_vae: {latent.min()}")
76
+ print(f"Max of latent_for_vae: {latent.max()}")
77
+ print(f"Mean of latent_for_vae: {latent.mean()}")
78
+ print(f"Std: {latent.std().item():.4f}")
79
+ if torch.isnan(latent).any() or torch.isinf(latent).any():
80
+ print("WARNING: Raw latents contain NaN or Inf values!")
81
+
82
+ reconstructed_image = vae.decode(latent).sample
83
+
84
+ # 2. Вытаскиваем обычную 3D-картинку [C, H, W] из 5D-видеотензора
85
+ if reconstructed_image.ndim == 5:
86
+ # Берем нулевой батч, все каналы, нулевой кадр, всю высоту и ширину
87
+ img_tensor = reconstructed_image[0, :, 0, :, :]
88
+ else:
89
+ img_tensor = reconstructed_image.squeeze(0) # На всякий случай, если VAE вернул 4D
90
+
91
+ img_array = img_tensor.cpu().numpy()
92
+ img_array = np.transpose(img_array, (1, 2, 0))
93
+ img_array = (img_array + 1) / 2 # Нормализация к [0, 1]
94
+ img_array = np.clip(img_array * 255, 0, 255).astype(np.uint8) # Преобразуем в uint8 для PIL
95
+
96
+ # Создаем PIL изображение из массива
97
+ pil_image = Image.fromarray(img_array)
98
+ print(f"Текст: {example['text']}")
99
+ print(f"Ключи: {', '.join(example.keys())}")
100
+ print(f"latent: {latent.shape}")
101
+ pil_image.save("1.jpg")
102
+
103
+ # Очистка памяти
104
+ if torch.cuda.is_available():
105
+ torch.cuda.empty_cache()
106
+ gc.collect()
107
+
108
+ return size_to_indices # Возвращаем словарь с индексами по группам
109
+
110
+ # Использование
111
+ if __name__ == "__main__":
112
+ # Путь к датасету
113
+ save_path = "datasets/ds234_640_vae_qwen"
114
+
115
+ # Анализ датасета
116
+ size_groups = analyze_dataset_by_size(save_path)
test.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:677906d20fb691440965fb107de2c9d8e9b7c75884d9e3e15b4375f4257df8ae
3
+ size 21416092
text_encoder/.ipynb_checkpoints/config-checkpoint.json ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3_5Model"
4
+ ],
5
+ "dtype": "bfloat16",
6
+ "image_token_id": 248056,
7
+ "model_type": "qwen3_5",
8
+ "text_config": {
9
+ "attention_bias": false,
10
+ "attention_dropout": 0.0,
11
+ "attn_output_gate": true,
12
+ "bos_token_id": null,
13
+ "dtype": "bfloat16",
14
+ "eos_token_id": 248044,
15
+ "full_attention_interval": 4,
16
+ "head_dim": 256,
17
+ "hidden_act": "silu",
18
+ "hidden_size": 1024,
19
+ "initializer_range": 0.02,
20
+ "intermediate_size": 3584,
21
+ "layer_types": [
22
+ "linear_attention",
23
+ "linear_attention",
24
+ "linear_attention",
25
+ "full_attention",
26
+ "linear_attention",
27
+ "linear_attention",
28
+ "linear_attention",
29
+ "full_attention",
30
+ "linear_attention",
31
+ "linear_attention",
32
+ "linear_attention",
33
+ "full_attention",
34
+ "linear_attention",
35
+ "linear_attention",
36
+ "linear_attention",
37
+ "full_attention",
38
+ "linear_attention",
39
+ "linear_attention",
40
+ "linear_attention",
41
+ "full_attention",
42
+ "linear_attention",
43
+ "linear_attention",
44
+ "linear_attention",
45
+ "full_attention"
46
+ ],
47
+ "linear_conv_kernel_dim": 4,
48
+ "linear_key_head_dim": 128,
49
+ "linear_num_key_heads": 16,
50
+ "linear_num_value_heads": 16,
51
+ "linear_value_head_dim": 128,
52
+ "mamba_ssm_dtype": "float32",
53
+ "max_position_embeddings": 262144,
54
+ "mlp_only_layers": [],
55
+ "model_type": "qwen3_5_text",
56
+ "mtp_num_hidden_layers": 1,
57
+ "mtp_use_dedicated_embeddings": false,
58
+ "num_attention_heads": 8,
59
+ "num_hidden_layers": 24,
60
+ "num_key_value_heads": 2,
61
+ "pad_token_id": null,
62
+ "partial_rotary_factor": 0.25,
63
+ "rms_norm_eps": 1e-06,
64
+ "rope_parameters": {
65
+ "mrope_interleaved": true,
66
+ "mrope_section": [
67
+ 11,
68
+ 11,
69
+ 10
70
+ ],
71
+ "partial_rotary_factor": 0.25,
72
+ "rope_theta": 10000000,
73
+ "rope_type": "default"
74
+ },
75
+ "tie_word_embeddings": true,
76
+ "use_cache": true,
77
+ "vocab_size": 248320
78
+ },
79
+ "tie_word_embeddings": true,
80
+ "transformers_version": "5.6.1",
81
+ "video_token_id": 248057,
82
+ "vision_config": {
83
+ "deepstack_visual_indexes": [],
84
+ "depth": 12,
85
+ "dtype": "bfloat16",
86
+ "hidden_act": "gelu_pytorch_tanh",
87
+ "hidden_size": 768,
88
+ "in_channels": 3,
89
+ "initializer_range": 0.02,
90
+ "intermediate_size": 3072,
91
+ "model_type": "qwen3_5_vision",
92
+ "num_heads": 12,
93
+ "num_position_embeddings": 2304,
94
+ "out_hidden_size": 1024,
95
+ "patch_size": 16,
96
+ "spatial_merge_size": 2,
97
+ "temporal_patch_size": 2
98
+ },
99
+ "vision_end_token_id": 248054,
100
+ "vision_start_token_id": 248053
101
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3_5Model"
4
+ ],
5
+ "dtype": "bfloat16",
6
+ "image_token_id": 248056,
7
+ "model_type": "qwen3_5",
8
+ "text_config": {
9
+ "attention_bias": false,
10
+ "attention_dropout": 0.0,
11
+ "attn_output_gate": true,
12
+ "bos_token_id": null,
13
+ "dtype": "bfloat16",
14
+ "eos_token_id": 248044,
15
+ "full_attention_interval": 4,
16
+ "head_dim": 256,
17
+ "hidden_act": "silu",
18
+ "hidden_size": 1024,
19
+ "initializer_range": 0.02,
20
+ "intermediate_size": 3584,
21
+ "layer_types": [
22
+ "linear_attention",
23
+ "linear_attention",
24
+ "linear_attention",
25
+ "full_attention",
26
+ "linear_attention",
27
+ "linear_attention",
28
+ "linear_attention",
29
+ "full_attention",
30
+ "linear_attention",
31
+ "linear_attention",
32
+ "linear_attention",
33
+ "full_attention",
34
+ "linear_attention",
35
+ "linear_attention",
36
+ "linear_attention",
37
+ "full_attention",
38
+ "linear_attention",
39
+ "linear_attention",
40
+ "linear_attention",
41
+ "full_attention",
42
+ "linear_attention",
43
+ "linear_attention",
44
+ "linear_attention",
45
+ "full_attention"
46
+ ],
47
+ "linear_conv_kernel_dim": 4,
48
+ "linear_key_head_dim": 128,
49
+ "linear_num_key_heads": 16,
50
+ "linear_num_value_heads": 16,
51
+ "linear_value_head_dim": 128,
52
+ "mamba_ssm_dtype": "float32",
53
+ "max_position_embeddings": 262144,
54
+ "mlp_only_layers": [],
55
+ "model_type": "qwen3_5_text",
56
+ "mtp_num_hidden_layers": 1,
57
+ "mtp_use_dedicated_embeddings": false,
58
+ "num_attention_heads": 8,
59
+ "num_hidden_layers": 24,
60
+ "num_key_value_heads": 2,
61
+ "pad_token_id": null,
62
+ "partial_rotary_factor": 0.25,
63
+ "rms_norm_eps": 1e-06,
64
+ "rope_parameters": {
65
+ "mrope_interleaved": true,
66
+ "mrope_section": [
67
+ 11,
68
+ 11,
69
+ 10
70
+ ],
71
+ "partial_rotary_factor": 0.25,
72
+ "rope_theta": 10000000,
73
+ "rope_type": "default"
74
+ },
75
+ "tie_word_embeddings": true,
76
+ "use_cache": true,
77
+ "vocab_size": 248320
78
+ },
79
+ "tie_word_embeddings": true,
80
+ "transformers_version": "5.6.1",
81
+ "video_token_id": 248057,
82
+ "vision_config": {
83
+ "deepstack_visual_indexes": [],
84
+ "depth": 12,
85
+ "dtype": "bfloat16",
86
+ "hidden_act": "gelu_pytorch_tanh",
87
+ "hidden_size": 768,
88
+ "in_channels": 3,
89
+ "initializer_range": 0.02,
90
+ "intermediate_size": 3072,
91
+ "model_type": "qwen3_5_vision",
92
+ "num_heads": 12,
93
+ "num_position_embeddings": 2304,
94
+ "out_hidden_size": 1024,
95
+ "patch_size": 16,
96
+ "spatial_merge_size": 2,
97
+ "temporal_patch_size": 2
98
+ },
99
+ "vision_end_token_id": 248054,
100
+ "vision_start_token_id": 248053
101
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be05a6e8dcacdae04865491110f227b71229110e321aa655982c4bd793ea411a
3
+ size 1706027688
tokenizer/chat_template.jinja ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- set image_count = namespace(value=0) %}
2
+ {%- set video_count = namespace(value=0) %}
3
+ {%- macro render_content(content, do_vision_count, is_system_content=false) %}
4
+ {%- if content is string %}
5
+ {{- content }}
6
+ {%- elif content is iterable and content is not mapping %}
7
+ {%- for item in content %}
8
+ {%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
9
+ {%- if is_system_content %}
10
+ {{- raise_exception('System message cannot contain images.') }}
11
+ {%- endif %}
12
+ {%- if do_vision_count %}
13
+ {%- set image_count.value = image_count.value + 1 %}
14
+ {%- endif %}
15
+ {%- if add_vision_id %}
16
+ {{- 'Picture ' ~ image_count.value ~ ': ' }}
17
+ {%- endif %}
18
+ {{- '<|vision_start|><|image_pad|><|vision_end|>' }}
19
+ {%- elif 'video' in item or item.type == 'video' %}
20
+ {%- if is_system_content %}
21
+ {{- raise_exception('System message cannot contain videos.') }}
22
+ {%- endif %}
23
+ {%- if do_vision_count %}
24
+ {%- set video_count.value = video_count.value + 1 %}
25
+ {%- endif %}
26
+ {%- if add_vision_id %}
27
+ {{- 'Video ' ~ video_count.value ~ ': ' }}
28
+ {%- endif %}
29
+ {{- '<|vision_start|><|video_pad|><|vision_end|>' }}
30
+ {%- elif 'text' in item %}
31
+ {{- item.text }}
32
+ {%- else %}
33
+ {{- raise_exception('Unexpected item type in content.') }}
34
+ {%- endif %}
35
+ {%- endfor %}
36
+ {%- elif content is none or content is undefined %}
37
+ {{- '' }}
38
+ {%- else %}
39
+ {{- raise_exception('Unexpected content type.') }}
40
+ {%- endif %}
41
+ {%- endmacro %}
42
+ {%- if not messages %}
43
+ {{- raise_exception('No messages provided.') }}
44
+ {%- endif %}
45
+ {%- if tools and tools is iterable and tools is not mapping %}
46
+ {{- '<|im_start|>system\n' }}
47
+ {{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
48
+ {%- for tool in tools %}
49
+ {{- "\n" }}
50
+ {{- tool | tojson }}
51
+ {%- endfor %}
52
+ {{- "\n</tools>" }}
53
+ {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
54
+ {%- if messages[0].role == 'system' %}
55
+ {%- set content = render_content(messages[0].content, false, true)|trim %}
56
+ {%- if content %}
57
+ {{- '\n\n' + content }}
58
+ {%- endif %}
59
+ {%- endif %}
60
+ {{- '<|im_end|>\n' }}
61
+ {%- else %}
62
+ {%- if messages[0].role == 'system' %}
63
+ {%- set content = render_content(messages[0].content, false, true)|trim %}
64
+ {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
65
+ {%- endif %}
66
+ {%- endif %}
67
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
68
+ {%- for message in messages[::-1] %}
69
+ {%- set index = (messages|length - 1) - loop.index0 %}
70
+ {%- if ns.multi_step_tool and message.role == "user" %}
71
+ {%- set content = render_content(message.content, false)|trim %}
72
+ {%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
73
+ {%- set ns.multi_step_tool = false %}
74
+ {%- set ns.last_query_index = index %}
75
+ {%- endif %}
76
+ {%- endif %}
77
+ {%- endfor %}
78
+ {%- if ns.multi_step_tool %}
79
+ {{- raise_exception('No user query found in messages.') }}
80
+ {%- endif %}
81
+ {%- for message in messages %}
82
+ {%- set content = render_content(message.content, true)|trim %}
83
+ {%- if message.role == "system" %}
84
+ {%- if not loop.first %}
85
+ {{- raise_exception('System message must be at the beginning.') }}
86
+ {%- endif %}
87
+ {%- elif message.role == "user" %}
88
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
89
+ {%- elif message.role == "assistant" %}
90
+ {%- set reasoning_content = '' %}
91
+ {%- if message.reasoning_content is string %}
92
+ {%- set reasoning_content = message.reasoning_content %}
93
+ {%- else %}
94
+ {%- if '</think>' in content %}
95
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
96
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
97
+ {%- endif %}
98
+ {%- endif %}
99
+ {%- set reasoning_content = reasoning_content|trim %}
100
+ {%- if loop.index0 > ns.last_query_index %}
101
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
102
+ {%- else %}
103
+ {{- '<|im_start|>' + message.role + '\n' + content }}
104
+ {%- endif %}
105
+ {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
106
+ {%- for tool_call in message.tool_calls %}
107
+ {%- if tool_call.function is defined %}
108
+ {%- set tool_call = tool_call.function %}
109
+ {%- endif %}
110
+ {%- if loop.first %}
111
+ {%- if content|trim %}
112
+ {{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
113
+ {%- else %}
114
+ {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
115
+ {%- endif %}
116
+ {%- else %}
117
+ {{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
118
+ {%- endif %}
119
+ {%- if tool_call.arguments is defined %}
120
+ {%- for args_name, args_value in tool_call.arguments|items %}
121
+ {{- '<parameter=' + args_name + '>\n' }}
122
+ {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
123
+ {{- args_value }}
124
+ {{- '\n</parameter>\n' }}
125
+ {%- endfor %}
126
+ {%- endif %}
127
+ {{- '</function>\n</tool_call>' }}
128
+ {%- endfor %}
129
+ {%- endif %}
130
+ {{- '<|im_end|>\n' }}
131
+ {%- elif message.role == "tool" %}
132
+ {%- if loop.previtem and loop.previtem.role != "tool" %}
133
+ {{- '<|im_start|>user' }}
134
+ {%- endif %}
135
+ {{- '\n<tool_response>\n' }}
136
+ {{- content }}
137
+ {{- '\n</tool_response>' }}
138
+ {%- if not loop.last and loop.nextitem.role != "tool" %}
139
+ {{- '<|im_end|>\n' }}
140
+ {%- elif loop.last %}
141
+ {{- '<|im_end|>\n' }}
142
+ {%- endif %}
143
+ {%- else %}
144
+ {{- raise_exception('Unexpected message role.') }}
145
+ {%- endif %}
146
+ {%- endfor %}
147
+ {%- if add_generation_prompt %}
148
+ {{- '<|im_start|>assistant\n' }}
149
+ {%- if enable_thinking is defined and enable_thinking is true %}
150
+ {{- '<think>\n' }}
151
+ {%- else %}
152
+ {{- '<think>\n\n</think>\n\n' }}
153
+ {%- endif %}
154
+ {%- endif %}
tokenizer/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06b9509352d2af50381ab2247e083b80d32d5c0aba91c272ca9ff729b6a0e523
3
+ size 19989325
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "audio_bos_token": "<|audio_start|>",
4
+ "audio_eos_token": "<|audio_end|>",
5
+ "audio_token": "<|audio_pad|>",
6
+ "backend": "tokenizers",
7
+ "bos_token": null,
8
+ "clean_up_tokenization_spaces": false,
9
+ "eos_token": "<|im_end|>",
10
+ "errors": "replace",
11
+ "image_token": "<|image_pad|>",
12
+ "is_local": false,
13
+ "local_files_only": false,
14
+ "model_max_length": 262144,
15
+ "model_specific_special_tokens": {
16
+ "audio_bos_token": "<|audio_start|>",
17
+ "audio_eos_token": "<|audio_end|>",
18
+ "audio_token": "<|audio_pad|>",
19
+ "image_token": "<|image_pad|>",
20
+ "video_token": "<|video_pad|>",
21
+ "vision_bos_token": "<|vision_start|>",
22
+ "vision_eos_token": "<|vision_end|>"
23
+ },
24
+ "pad_token": "<|endoftext|>",
25
+ "pretokenize_regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
26
+ "split_special_tokens": false,
27
+ "tokenizer_class": "Qwen2Tokenizer",
28
+ "unk_token": null,
29
+ "video_token": "<|video_pad|>",
30
+ "vision_bos_token": "<|vision_start|>",
31
+ "vision_eos_token": "<|vision_end|>"
32
+ }
train-Copy1.py ADDED
@@ -0,0 +1,924 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import wandb, comet_ml
7
+ import random, time
8
+ import gc
9
+ import bitsandbytes as bnb
10
+ import torch.nn.functional as F
11
+ import argparse
12
+
13
+ from datetime import datetime
14
+ from diffusers import CosmosTransformer3DModel, AutoencoderKLQwenImage, FlowMatchEulerDiscreteScheduler
15
+ from transformers import Qwen3_5Tokenizer, Qwen3_5ForConditionalGeneration
16
+ from torch.utils.data import DataLoader, Sampler
17
+ from torch.optim.lr_scheduler import LambdaLR
18
+ from collections import defaultdict
19
+ from accelerate import Accelerator
20
+ from datasets import load_from_disk
21
+ from tqdm import tqdm
22
+ from PIL import Image, ImageOps
23
+ from torch.utils.checkpoint import checkpoint
24
+ from diffusers.models.attention_processor import AttnProcessor2_0
25
+ from contextlib import nullcontext
26
+ from transformers.optimization import Adafactor
27
+
28
+ # Muon not tested! pip install git+https://github.com/recoilme/muon_adamw8bit.git
29
+ from muon_adamw8bit import MuonAdamW8bit
30
+
31
+ os.environ["NCCL_P2P_DISABLE"] = "1"
32
+ os.environ["NCCL_IB_DISABLE"] = "1" # comment this on H100!
33
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
34
+
35
+ # --------------------------- Параметры ---------------------------
36
+ ds_path = "datasets/ds234_640_vae_qwen"
37
+ project = "transformer"
38
+
39
+ gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
40
+ local_bs = max(1, int((gpu_mem_gb / 32) * 7))
41
+ num_gpus = torch.cuda.device_count()
42
+ batch_size = local_bs * num_gpus
43
+
44
+ base_learning_rate = 4e-5
45
+ min_learning_rate = 4e-6
46
+
47
+ learning_rate_scale = 3
48
+ base_learning_rate = base_learning_rate / learning_rate_scale
49
+ min_learning_rate = min_learning_rate / learning_rate_scale
50
+ print(f"Calculated params max-lr:{base_learning_rate} min-lr:{min_learning_rate} GPUs: {num_gpus}, Global BS: {batch_size}")
51
+
52
+ num_epochs = num_gpus
53
+ sink_interval_share = 10
54
+ sample_interval_min = 20
55
+ cfg_dropout = 0.10
56
+ # Время t, bias = -0.5 (Фокус на Деталях ~300) bias = 0.5 (Фокус на структуре) bias = 0 (колокол/ равномерно)
57
+ sigmoid_bias = 0.1
58
+ max_length = 250
59
+ use_precomputed_embeddings = False
60
+ use_wandb = False
61
+ use_comet_ml = False
62
+ save_model = True
63
+ use_decay = True
64
+ fbp = False
65
+ torch_compile = False
66
+ transformer_gradient = True
67
+ loss_normalize = False
68
+ fixed_seed = False
69
+ shuffle = True
70
+ optimizer_type = "adafactor"
71
+
72
+ if optimizer_type == "muon_adam8bit":
73
+ batch_size = num_gpus * max(1, int((gpu_mem_gb / 32) * 3))
74
+ muon_lr_scale = 500
75
+
76
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
77
+ comet_ml_workspace = "recoilme"
78
+ torch.backends.cuda.matmul.allow_tf32 = True
79
+ torch.backends.cudnn.allow_tf32 = True
80
+ torch.backends.cuda.enable_flash_sdp(True)
81
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
82
+ torch.backends.cuda.enable_math_sdp(False)
83
+ save_barrier = 1.25
84
+ warmup_percent = 0.0025
85
+ betta2 = 0.997
86
+ eps = 1e-6
87
+ clip_grad_norm = 1.0
88
+ limit = 0
89
+ checkpoints_folder = ""
90
+ gradient_accumulation_steps = 1
91
+
92
+ dtype = torch.float32
93
+ mixed_precision = "bf16"
94
+
95
+ # Параметры для диффузии
96
+ n_diffusion_steps = 40
97
+ samples_to_generate = 12
98
+ guidance_scale = 7.0
99
+
100
+ # Папки для сохранения результатов
101
+ generated_folder = "samples"
102
+ os.makedirs(generated_folder, exist_ok=True)
103
+
104
+ # Настройка seed
105
+ current_date = datetime.now()
106
+ seed = int(current_date.strftime("%Y%m%d")) + 42
107
+ if fixed_seed:
108
+ torch.manual_seed(seed)
109
+ np.random.seed(seed)
110
+ random.seed(seed)
111
+ if torch.cuda.is_available():
112
+ torch.cuda.manual_seed_all(seed)
113
+
114
+ accelerator = Accelerator(
115
+ mixed_precision=mixed_precision,
116
+ gradient_accumulation_steps=gradient_accumulation_steps
117
+ )
118
+ device = accelerator.device
119
+
120
+ print("init")
121
+ parser = argparse.ArgumentParser(description='Train a model on a dataset.')
122
+ parser.add_argument('--ds-path', type=str, default=ds_path, help='Path to the dataset')
123
+ parser.add_argument('--ep', type=int, default=num_epochs, help='Number of epochs to train the model')
124
+ parser.add_argument('--batch', type=int, default=batch_size, help='Total batch size')
125
+ parser.add_argument('--min-lr', type=float, default=min_learning_rate, help='Minimum learning rate')
126
+ parser.add_argument('--max-lr', type=float, default=base_learning_rate, help='Maximum learning rate')
127
+ parser.add_argument('--dry-run', action='store_true',default=False, help='Dry run train without saving/sampling')
128
+ parser.add_argument('--lvl', type=float, default=0.0, help='Train level, from 0.5 to 5')
129
+
130
+ args = parser.parse_args()
131
+
132
+ batch_size = args.batch
133
+ ds_path = args.ds_path
134
+ base_learning_rate = args.max_lr
135
+ min_learning_rate = args.min_lr
136
+ num_epochs = args.ep
137
+ lvl = args.lvl
138
+ if args.dry_run:
139
+ save_model = False
140
+ if lvl >= 0.1:
141
+ base_learning_rate = base_learning_rate / lvl
142
+ min_learning_rate = min_learning_rate / lvl
143
+ print(f"max-lr:{base_learning_rate} min-lr:{min_learning_rate}")
144
+
145
+ # --------------------------- Инициализация WandB ---------------------------
146
+ if accelerator.is_main_process:
147
+ if use_wandb:
148
+ wandb.init(project=project, config={
149
+ "batch_size": batch_size,
150
+ "base_learning_rate": base_learning_rate,
151
+ "num_epochs": num_epochs,
152
+ "optimizer_type": optimizer_type,
153
+ })
154
+ if use_comet_ml:
155
+ from comet_ml import Experiment
156
+ comet_experiment = Experiment(
157
+ api_key=comet_ml_api_key,
158
+ project_name=project,
159
+ workspace=comet_ml_workspace
160
+ )
161
+ hyper_params = {
162
+ "batch_size": batch_size,
163
+ "base_learning_rate": base_learning_rate,
164
+ "num_epochs": num_epochs,
165
+ }
166
+ comet_experiment.log_parameters(hyper_params)
167
+
168
+ # --------------------------- Загрузка моделей ---------------------------
169
+ vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device).to(dtype=dtype).eval()
170
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
171
+ tokenizer = None
172
+ text_encoder = None
173
+
174
+ def load_text_encoder():
175
+ global tokenizer, text_encoder
176
+ if tokenizer is None:
177
+ tokenizer = Qwen3_5Tokenizer.from_pretrained("tokenizer")
178
+ if text_encoder is None:
179
+ text_encoder = Qwen3_5ForConditionalGeneration.from_pretrained(
180
+ "text_encoder",
181
+ torch_dtype=dtype
182
+ ).to(device).eval()
183
+
184
+ load_text_encoder()
185
+
186
+ @torch.no_grad()
187
+ def encode_texts(text, max_length=max_length):
188
+ if text is None:
189
+ text = ""
190
+ if isinstance(text, str):
191
+ text = [text]
192
+
193
+ formatted_prompts = []
194
+ for t in text:
195
+ messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
196
+ formatted_prompts.append(
197
+ tokenizer.apply_chat_template(
198
+ messages,
199
+ tokenize=False,
200
+ add_generation_prompt=False
201
+ )
202
+ )
203
+
204
+ toks = tokenizer(
205
+ formatted_prompts,
206
+ padding="max_length",
207
+ max_length=max_length,
208
+ truncation=True,
209
+ return_tensors="pt"
210
+ ).to(device)
211
+
212
+ outputs = text_encoder(
213
+ input_ids=toks.input_ids,
214
+ attention_mask=toks.attention_mask,
215
+ output_hidden_states=True
216
+ )
217
+
218
+ hidden = outputs.hidden_states[-2].to(dtype=dtype)
219
+
220
+ lengths = toks.attention_mask.sum(dim=1)
221
+ for i, length in enumerate(lengths):
222
+ hidden[i, length:] = 0
223
+
224
+ return hidden, toks.attention_mask.to(dtype=torch.int64)
225
+
226
+ @torch.no_grad()
227
+ def encode_texts_fast(text, max_length=max_length):
228
+ if text is None: text = ""
229
+ if isinstance(text, str): text = [text]
230
+
231
+ formatted_prompts = []
232
+ for t in text:
233
+ messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
234
+ formatted_prompts.append(tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False))
235
+
236
+ toks = tokenizer(formatted_prompts, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt").to(device)
237
+ outputs = text_encoder(input_ids=toks.input_ids, attention_mask=toks.attention_mask, output_hidden_states=True)
238
+
239
+ last_hidden = outputs.hidden_states[-2].to(dtype=dtype)
240
+
241
+ lengths = toks.attention_mask.sum(dim=1)
242
+ for i, length in enumerate(lengths):
243
+ last_hidden[i, length:] = 0
244
+
245
+ return last_hidden, toks.attention_mask.to(dtype=torch.int64)
246
+
247
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
248
+ if shift_factor is None:
249
+ shift_factor = 0.0
250
+
251
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
252
+ if scaling_factor is None:
253
+ scaling_factor = 1.0
254
+
255
+ mean = getattr(vae.config, "latents_mean", None)
256
+ std = getattr(vae.config, "latents_std", None)
257
+ if mean is not None and std is not None:
258
+ latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
259
+ latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)
260
+ # Внимание: Cosmos использует инвертированный std для декодирования (1.0 / std)
261
+ #latents_std = 1.0 / torch.tensor(std).view(1, len(std), 1, 1, 1)
262
+ else:
263
+ latents_std = None
264
+ latents_mean = None
265
+
266
+ if scheduler is not None:
267
+ scheduler.register_to_config(
268
+ sigma_max=getattr(scheduler.config, "sigma_max", 80.0),
269
+ sigma_min=getattr(scheduler.config, "sigma_min", 0.002),
270
+ sigma_data=getattr(scheduler.config, "sigma_data", 1.0),
271
+ final_sigmas_type=getattr(scheduler.config, "final_sigmas_type", "sigma_min"),
272
+ )
273
+
274
+ import numpy as np
275
+ from torch.utils.data import Sampler
276
+
277
+ class DistributedResolutionBatchSampler(Sampler):
278
+ def __init__(self, dataset, batch_size, num_replicas, rank, drop_last=True, shuffle=True):
279
+ self.dataset = dataset
280
+ self.num_replicas = num_replicas
281
+ self.rank = rank
282
+ self.shuffle = shuffle
283
+ self.drop_last = drop_last
284
+ self.epoch = 0
285
+
286
+ self.batch_size = max(1, batch_size // num_replicas)
287
+ self.global_batch = self.batch_size * num_replicas
288
+
289
+ try:
290
+ widths = np.asarray(dataset["width"])
291
+ heights = np.asarray(dataset["height"])
292
+ except KeyError:
293
+ widths = np.zeros(len(dataset))
294
+ heights = np.zeros(len(dataset))
295
+
296
+ groups = {}
297
+ for i, (w, h) in enumerate(zip(widths, heights)):
298
+ groups.setdefault((w, h), []).append(i)
299
+
300
+ all_batches = []
301
+ for indices in groups.values():
302
+ idx = np.asarray(indices, dtype=np.int64)
303
+ num_batches = len(idx) // self.global_batch
304
+ if num_batches == 0:
305
+ continue
306
+ idx = idx[: num_batches * self.global_batch]
307
+ batches = idx.reshape(num_batches, self.global_batch)
308
+ all_batches.append(batches)
309
+
310
+ if len(all_batches) > 0:
311
+ self.global_batches = np.concatenate(all_batches, axis=0)
312
+ else:
313
+ self.global_batches = np.empty((0, self.global_batch), dtype=np.int64)
314
+
315
+ self.num_batches = len(self.global_batches)
316
+
317
+ def __iter__(self):
318
+ rng = np.random.RandomState(self.epoch)
319
+ order = np.arange(self.num_batches)
320
+
321
+ if self.shuffle:
322
+ rng.shuffle(order)
323
+
324
+ start = self.rank * self.batch_size
325
+ end = start + self.batch_size
326
+
327
+ for i in order:
328
+ yield self.global_batches[i][start:end]
329
+
330
+ def __len__(self):
331
+ return self.num_batches
332
+
333
+ def set_epoch(self, epoch):
334
+ self.epoch = epoch
335
+
336
+ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
337
+ size_groups = defaultdict(list)
338
+ try:
339
+ widths = dataset["width"]
340
+ heights = dataset["height"]
341
+ except KeyError:
342
+ widths = [0] * len(dataset)
343
+ heights = [0] * len(dataset)
344
+ for i, (w, h) in enumerate(zip(widths, heights)):
345
+ size = (w, h)
346
+ size_groups[size].append(i)
347
+
348
+ fixed_samples = {}
349
+ for size, indices in size_groups.items():
350
+ n_samples = min(samples_per_group, len(indices))
351
+ if len(size_groups)==1:
352
+ n_samples = samples_to_generate
353
+ if n_samples == 0:
354
+ continue
355
+ sample_indices = random.sample(indices, n_samples)
356
+ samples_data = [dataset[idx] for idx in sample_indices]
357
+
358
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
359
+
360
+ if latents.ndim == 4:
361
+ latents = latents.unsqueeze(2)
362
+ elif latents.ndim == 6:
363
+ latents = latents.squeeze(2)
364
+
365
+ texts = [item["text"] for item in samples_data]
366
+
367
+ if use_precomputed_embeddings:
368
+ embeddings = torch.tensor(
369
+ np.array([item["embeddings"] for item in samples_data]),
370
+ device=device,
371
+ dtype=dtype
372
+ )
373
+ masks = torch.tensor(
374
+ np.array([item["attention_mask"] for item in samples_data]),
375
+ device=device,
376
+ dtype=torch.int64
377
+ )
378
+ else:
379
+ embeddings, masks = encode_texts(texts,max_length)
380
+
381
+ fixed_samples[size] = (latents, embeddings, masks, texts)
382
+
383
+ print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
384
+ return fixed_samples
385
+
386
+ if limit > 0:
387
+ dataset = load_from_disk(ds_path).select(range(limit))
388
+ else:
389
+ dataset = load_from_disk(ds_path)
390
+
391
+ print(f"images: {len(dataset)}")
392
+
393
+ def collate_fn_simple(batch):
394
+ latents = torch.from_numpy(
395
+ np.array([item["vae"] for item in batch], dtype=np.float16)
396
+ ).to(device, dtype=dtype)
397
+
398
+ if latents.ndim == 4:
399
+ latents = latents.unsqueeze(2)
400
+ elif latents.ndim == 6:
401
+ latents = latents.squeeze(2)
402
+
403
+ if use_precomputed_embeddings:
404
+ embeddings = torch.from_numpy(
405
+ np.array([item["embeddings"] for item in batch], dtype=np.float16)
406
+ ).to(device, dtype=dtype)
407
+
408
+ attention_mask = torch.from_numpy(
409
+ np.array([item["attention_mask"] for item in batch], dtype=np.int64)
410
+ ).to(device)
411
+
412
+ return latents, embeddings, attention_mask
413
+
414
+ raw_texts = [item["text"] for item in batch]
415
+
416
+ texts = [
417
+ "" if t.lower().startswith("zero")
418
+ else "" if random.random() < cfg_dropout
419
+ else t[1:].lstrip() if t.startswith(".")
420
+ else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","").strip()
421
+ for t in raw_texts
422
+ ]
423
+
424
+ embeddings, attention_mask = encode_texts(texts,max_length)
425
+ attention_mask = attention_mask.to(dtype=torch.int64)
426
+
427
+ return latents, embeddings, attention_mask
428
+
429
+ batch_sampler = DistributedResolutionBatchSampler(
430
+ dataset=dataset,
431
+ batch_size=batch_size,
432
+ num_replicas=accelerator.num_processes,
433
+ rank=accelerator.process_index,
434
+ shuffle = shuffle
435
+ )
436
+
437
+ dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
438
+
439
+ if accelerator.is_main_process:
440
+ print("Total samples", len(dataloader))
441
+ dataloader = accelerator.prepare(dataloader)
442
+
443
+ start_epoch = 0
444
+ global_step = 0
445
+ total_training_steps = (len(dataloader) * num_epochs)
446
+ world_size = accelerator.state.num_processes
447
+
448
+ latest_checkpoint = os.path.join(checkpoints_folder, project)
449
+ if os.path.isdir(latest_checkpoint):
450
+ print("Загружаем Transformer из чекпоинта:", latest_checkpoint)
451
+ transformer = CosmosTransformer3DModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
452
+ if transformer_gradient:
453
+ transformer.enable_gradient_checkpointing()
454
+ else:
455
+ raise FileNotFoundError(f"Transformer checkpoint not found at {latest_checkpoint}")
456
+
457
+ def create_optimizer(name, params):
458
+ if name == "adam8bit":
459
+ return bnb.optim.AdamW8bit(
460
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001
461
+ )
462
+ elif name == "adam":
463
+ return torch.optim.AdamW(
464
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001
465
+ )
466
+ elif name == "adafactor":
467
+ return Adafactor(
468
+ params,
469
+ lr=base_learning_rate,
470
+ eps=(1e-30, 1e-3),
471
+ clip_threshold=1.0,
472
+ decay_rate=-0.8,
473
+ beta1=None,
474
+ weight_decay=0.001,
475
+ relative_step=False,
476
+ scale_parameter=False,
477
+ warmup_init=False
478
+ )
479
+ elif name == "muon_adam8bit":
480
+ return MuonAdamW8bit(
481
+ params,
482
+ lr=base_learning_rate,
483
+ betas=(0.9, betta2),
484
+ eps=eps,
485
+ weight_decay=0.01,
486
+ muon_lr_mult=muon_lr_scale,
487
+ )
488
+ else:
489
+ raise ValueError(f"Unknown optimizer: {name}")
490
+
491
+ if fbp:
492
+ trainable_params = list(transformer.parameters())
493
+ optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
494
+ def optimizer_hook(param):
495
+ optimizer_dict[param].step()
496
+ optimizer_dict[param].zero_grad(set_to_none=True)
497
+ for param in trainable_params:
498
+ param.register_post_accumulate_grad_hook(optimizer_hook)
499
+ transformer, optimizer = accelerator.prepare(transformer, optimizer_dict)
500
+ else:
501
+ #transformer.requires_grad_(True)
502
+ # 1. Сначала замораживаем ВООБЩЕ ВСЕ параметры
503
+ transformer.requires_grad_(False)
504
+
505
+ # 2. Определяем ключевое слово для слоев, которые нужно учить (Cross-Attention)
506
+ trainable_params_names = ["attn2"]
507
+ trainable_params = []
508
+
509
+ print("--- РАЗМОРОЖЕННЫЕ СЛОИ ---")
510
+ for name, param in transformer.named_parameters():
511
+ if any(target in name for target in trainable_params_names):
512
+ param.requires_grad_(True) # Размораживаем
513
+ trainable_params.append(param)
514
+ print(f"Обучаемый слой: {name}")
515
+ print("--------------------------")
516
+
517
+ # Защита от дурака
518
+ if len(trainable_params) == 0:
519
+ raise ValueError("Ошибка: ни один слой не был разморожен! Проверь ключи.")
520
+
521
+ optimizer = create_optimizer(optimizer_type, transformer.parameters())
522
+
523
+ def lr_schedule(step):
524
+ x = step / (total_training_steps * world_size)
525
+ warmup = warmup_percent
526
+ if not use_decay:
527
+ return base_learning_rate
528
+ if x < warmup:
529
+ return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
530
+ decay_ratio = (x - warmup) / (1 - warmup)
531
+ return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
532
+ (1 + math.cos(math.pi * decay_ratio))
533
+ lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
534
+
535
+ if torch_compile:
536
+ print("Compiling Transformer... Это займет несколько минут, не прерывайте!")
537
+ transformer = torch.compile(transformer)
538
+ print("Compiling - ok")
539
+
540
+ if not fbp:
541
+ transformer, optimizer, lr_scheduler = accelerator.prepare(transformer, optimizer, lr_scheduler)
542
+
543
+ # Фиксированные семплы
544
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
545
+
546
+ def get_negative_embedding(neg_prompt="", batch_size=1):
547
+ if not neg_prompt:
548
+ hidden_dim = 2048
549
+ seq_len = max_length
550
+ empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
551
+ empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
552
+ return empty_emb, empty_mask
553
+
554
+ uncond_emb, uncond_mask = encode_texts([neg_prompt],max_length)
555
+ uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
556
+ uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
557
+
558
+ return uncond_emb, uncond_mask
559
+
560
+ if use_precomputed_embeddings:
561
+ load_text_encoder()
562
+ uncond_emb, uncond_mask = get_negative_embedding("low quality")
563
+ uncond_emb = uncond_emb.to("cpu")
564
+ uncond_mask = uncond_mask.to("cpu")
565
+ del text_encoder
566
+ torch.cuda.empty_cache()
567
+ gc.collect()
568
+ text_encoder = None
569
+ else:
570
+ uncond_emb, uncond_mask = get_negative_embedding("low quality")
571
+
572
+ def pad_to_match(a, b, pad_value=0):
573
+ Ta, Tb = a.shape[1], b.shape[1]
574
+ if Ta == Tb:
575
+ return a, b
576
+ T = max(Ta, Tb)
577
+ def pad(x, T_target):
578
+ pad_len = T_target - x.shape[1]
579
+ if pad_len <= 0:
580
+ return x
581
+ return torch.nn.functional.pad(x, (0, 0, 0, pad_len), value=pad_value)
582
+ return pad(a, T), pad(b, T)
583
+
584
+ @torch.compiler.disable()
585
+ @torch.no_grad()
586
+ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
587
+ uncond_emb, uncond_mask = uncond_data
588
+ uncond_emb = uncond_emb.to(device)
589
+ uncond_mask = uncond_mask.to(device)
590
+
591
+ original_model = None
592
+ try:
593
+ if not torch_compile:
594
+ original_model = accelerator.unwrap_model(transformer, keep_torch_compile=True).eval()
595
+ else:
596
+ original_model = transformer.eval()
597
+
598
+ vae.to(device=device).eval()
599
+
600
+ all_generated_images = []
601
+ all_captions = []
602
+
603
+ for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
604
+ width, height = size
605
+
606
+ curr_batch_size = sample_latents.shape[0]
607
+ in_channels = original_model.config.in_channels
608
+
609
+ sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
610
+
611
+ sigmas_dtype = torch.float32
612
+ sigmas = torch.linspace(0, 1, n_diffusion_steps, dtype=sigmas_dtype)
613
+ scheduler.set_timesteps(sigmas=sigmas, device=device)
614
+
615
+ if scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
616
+ scheduler.sigmas[-1] = scheduler.sigmas[-2]
617
+ if scheduler.sigmas[-1] == 0.0:
618
+ scheduler.sigmas[-1] = 1e-4
619
+
620
+ sigma_max = getattr(scheduler.config, "sigma_max", 80.0)
621
+
622
+ latents = torch.randn(
623
+ (curr_batch_size, in_channels, 1, sample_latents.shape[3], sample_latents.shape[4]),
624
+ device=device,
625
+ dtype=dtype,
626
+ generator=torch.Generator(device=device).manual_seed(seed)
627
+ ) * sigma_max
628
+
629
+ padding_mask = torch.zeros((1, 1, sample_latents.shape[3], sample_latents.shape[4]), device=device, dtype=dtype)
630
+
631
+ if guidance_scale != 1:
632
+ neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
633
+ neg_emb_batch, sample_text_embeddings = pad_to_match(neg_emb_batch, sample_text_embeddings)
634
+
635
+ for i, t in enumerate(scheduler.timesteps):
636
+ current_sigma = scheduler.sigmas[i]
637
+ if current_sigma == 0.0:
638
+ current_sigma = torch.tensor(1e-4, dtype=current_sigma.dtype, device=device)
639
+
640
+ current_t = current_sigma / (current_sigma + 1.0)
641
+ c_in = 1.0 - current_t
642
+ c_skip = 1.0 - current_t
643
+ c_out = -current_t
644
+
645
+ latent_model_input = (latents * c_in).to(dtype)
646
+
647
+ t_val = float(current_t.item()) if torch.is_tensor(current_t) else float(current_t)
648
+ timestep_tensor = torch.tensor([t_val], device=device, dtype=dtype).expand(curr_batch_size)
649
+
650
+ noise_pred = original_model(
651
+ hidden_states=latent_model_input,
652
+ timestep=timestep_tensor,
653
+ encoder_hidden_states=sample_text_embeddings,
654
+ padding_mask=padding_mask,
655
+ return_dict=False
656
+ )[0]
657
+
658
+ noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(dtype)
659
+
660
+ if guidance_scale != 1:
661
+ noise_pred_uncond = original_model(
662
+ hidden_states=latent_model_input,
663
+ timestep=timestep_tensor,
664
+ encoder_hidden_states=neg_emb_batch,
665
+ padding_mask=padding_mask,
666
+ return_dict=False
667
+ )[0]
668
+ noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(dtype)
669
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
670
+
671
+ noise_pred = (latents - noise_pred) / current_sigma
672
+ latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
673
+
674
+ current_latents = latents
675
+ if step == 0:
676
+ current_latents = sample_latents
677
+
678
+ if latents_mean is not None and latents_std is not None:
679
+ sigma_data = getattr(scheduler.config, "sigma_data", 1.0)
680
+ # Переводим векторы нормализации в float32
681
+ l_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, torch.float32)
682
+ l_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, torch.float32)
683
+
684
+ # Кастуем латенты в float32 перед умножением, чтобы сохранить точность
685
+ latents_for_decode = (current_latents.to(torch.float32) * l_std) / sigma_data + l_mean
686
+ else:
687
+ latents_for_decode = current_latents.to(torch.float32)
688
+
689
+ # 2. Декодируем, ПРИНУДИТЕЛЬНО ВКЛЮЧИВ MATH_SDP только для этого шага!
690
+ with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
691
+ decoded = vae.decode(latents_for_decode).sample
692
+
693
+ # 3. Отсекаем лишнее видео-измерение
694
+ if decoded.ndim == 5:
695
+ decoded = decoded[:, :, 0, :, :]
696
+
697
+ # 4. Он уже во float32, можно сразу пускать в цикл
698
+ decoded_fp32 = decoded
699
+
700
+
701
+ for img_idx, img_tensor in enumerate(decoded_fp32):
702
+ img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
703
+ img = img.transpose(1, 2, 0)
704
+
705
+ if np.isnan(img).any():
706
+ print("NaNs found, saving stopped! Step:", step)
707
+ img = np.nan_to_num(img, nan=0.0)
708
+ pil_img = Image.fromarray((img * 255).astype("uint8"))
709
+
710
+ max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
711
+ max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
712
+ max_w_overall = max(255, max_w_overall)
713
+ max_h_overall = max(255, max_h_overall)
714
+
715
+ padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
716
+ all_generated_images.append(padded_img)
717
+
718
+ caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
719
+ all_captions.append(caption_text)
720
+
721
+ sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
722
+ pil_img.save(sample_path, "JPEG", quality=95)
723
+
724
+ if use_wandb and accelerator.is_main_process:
725
+ wandb_images = [
726
+ wandb.Image(img, caption=f"{all_captions[i]}")
727
+ for i, img in enumerate(all_generated_images)
728
+ ]
729
+ wandb.log({"generated_images": wandb_images})
730
+ if use_comet_ml and accelerator.is_main_process:
731
+ for i, img in enumerate(all_generated_images):
732
+ comet_experiment.log_image(
733
+ image_data=img,
734
+ name=f"step_{step}_img_{i}",
735
+ step=step,
736
+ metadata={"caption": all_captions[i]}
737
+ )
738
+ finally:
739
+ vae.to("cpu")
740
+ uncond_emb = uncond_emb.to("cpu")
741
+ uncond_mask = uncond_mask.to("cpu")
742
+ try:
743
+ all_generated_images.clear()
744
+ all_captions.clear()
745
+ del all_generated_images, all_captions
746
+ del latents, current_latents, latent_model_input
747
+ del decoded, decoded_fp32
748
+ del sample_latents, sample_text_embeddings, sample_mask
749
+ del noise_pred, noise_pred_uncond
750
+ except UnboundLocalError:
751
+ pass
752
+
753
+ torch.cuda.synchronize()
754
+ torch.cuda.empty_cache()
755
+ gc.collect()
756
+
757
+ if accelerator.is_main_process:
758
+ if save_model:
759
+ print("Генерация сэмплов до старта обучения...")
760
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
761
+ accelerator.wait_for_everyone()
762
+
763
+ def save_checkpoint(model_net, variant=""):
764
+ if accelerator.is_main_process:
765
+ model_to_save = None
766
+ if not torch_compile:
767
+ model_to_save = accelerator.unwrap_model(model_net)
768
+ else:
769
+ model_to_save = model_net
770
+
771
+ if variant != "":
772
+ model_to_save.to(dtype=torch.bfloat16).save_pretrained(
773
+ os.path.join(checkpoints_folder, f"{project}"), variant=variant
774
+ )
775
+ else:
776
+ model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
777
+
778
+ torch.cuda.synchronize()
779
+ torch.cuda.empty_cache()
780
+ gc.collect()
781
+
782
+ if accelerator.is_main_process:
783
+ print(f"Total steps per GPU: {total_training_steps}")
784
+
785
+ epoch_loss_points = []
786
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
787
+
788
+ steps_per_epoch = len(dataloader)
789
+ sink_interval = max(1, steps_per_epoch // sink_interval_share)
790
+ min_loss = 4.
791
+ last_sample_time = time.time()
792
+ sample_interval_seconds = sample_interval_min * 60
793
+
794
+ for epoch in range(start_epoch, start_epoch + num_epochs):
795
+ batch_losses = []
796
+ batch_grads = []
797
+ batch_sampler.set_epoch(epoch)
798
+ accelerator.wait_for_everyone()
799
+ transformer.train()
800
+
801
+ for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
802
+
803
+ if save_model == False and epoch == 0 and step == 5 :
804
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
805
+ print(f"Шаг {step}: {used_gb:.2f} GB")
806
+
807
+ amp_context = accelerator.autocast() if torch_compile else nullcontext()
808
+ with accelerator.accumulate(transformer):
809
+ with amp_context:
810
+ noise = torch.randn_like(latents, dtype=latents.dtype)
811
+
812
+ t = torch.sigmoid(torch.randn(latents.shape[0], device=latents.device, dtype=latents.dtype) + sigmoid_bias)
813
+
814
+ noisy_latents_5d = (1.0 - t.view(-1, 1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1, 1) * noise
815
+ target_5d = noise - latents
816
+
817
+ padding_mask = torch.zeros((1, 1, latents.shape[3], latents.shape[4]), device=device, dtype=dtype)
818
+
819
+ timestep_tensor = t.flatten().to(dtype)
820
+
821
+ model_pred = transformer(
822
+ hidden_states=noisy_latents_5d,
823
+ timestep=timestep_tensor,
824
+ encoder_hidden_states=embeddings,
825
+ padding_mask=padding_mask,
826
+ return_dict=False
827
+ )[0]
828
+
829
+ mse_loss = F.mse_loss(model_pred.float(), target_5d.float())
830
+ batch_losses.append(mse_loss.detach().item())
831
+
832
+ if (global_step % 100 == 0) or (global_step % sink_interval == 0):
833
+ accelerator.wait_for_everyone()
834
+
835
+ losses_dict = {}
836
+ losses_dict["mse"] = mse_loss
837
+
838
+ if (global_step % 100 == 0) or (global_step % sink_interval == 0):
839
+ accelerator.wait_for_everyone()
840
+
841
+ accelerator.backward(mse_loss)
842
+
843
+ if (global_step % 100 == 0) or (global_step % sink_interval == 0):
844
+ accelerator.wait_for_everyone()
845
+
846
+ grad = 0.0
847
+ if not fbp:
848
+ if accelerator.sync_gradients:
849
+ grad_val = accelerator.clip_grad_norm_(transformer.parameters(), clip_grad_norm)
850
+ grad = grad_val.float().item() if torch.is_tensor(grad_val) else float(grad_val)
851
+ optimizer.step()
852
+ lr_scheduler.step()
853
+ optimizer.zero_grad(set_to_none=True)
854
+
855
+ if accelerator.sync_gradients:
856
+ global_step += 1
857
+ progress_bar.update(1)
858
+ if accelerator.is_main_process:
859
+ if fbp:
860
+ current_lr = base_learning_rate
861
+ else:
862
+ current_lr = lr_scheduler.get_last_lr()[0]
863
+ batch_grads.append(grad)
864
+
865
+ log_data = {}
866
+ log_data["loss_mse"] = mse_loss.detach().item()
867
+ log_data["lr"] = current_lr
868
+ log_data["grad"] = grad
869
+ if accelerator.sync_gradients:
870
+ if use_wandb:
871
+ wandb.log(log_data, step=global_step)
872
+ if use_comet_ml:
873
+ comet_experiment.log_metrics(log_data, step=global_step)
874
+
875
+ current_time = time.time()
876
+ is_time_to_sample = (current_time - last_sample_time) >= sample_interval_seconds
877
+ if is_time_to_sample or global_step == 50:
878
+ if save_model:
879
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
880
+ elif epoch % 10 == 0:
881
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
882
+ last_n = sink_interval
883
+
884
+ if save_model:
885
+ has_losses = len(batch_losses) > 0
886
+ avg_sample_loss = np.mean(batch_losses[-sink_interval:]) if has_losses else 0.0
887
+ last_loss = batch_losses[-1] if has_losses else 0.0
888
+ max_loss = max(avg_sample_loss, last_loss)
889
+ should_save = max_loss < min_loss * save_barrier
890
+ print(
891
+ f"Saving: {should_save} | Max: {max_loss:.4f} | "
892
+ f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
893
+ )
894
+ if should_save:
895
+ min_loss = max_loss
896
+ save_checkpoint(transformer)
897
+ last_sample_time = current_time
898
+ transformer.train()
899
+
900
+ if accelerator.is_main_process:
901
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
902
+ avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
903
+
904
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
905
+ log_data_ep = {
906
+ "epoch_loss": avg_epoch_loss,
907
+ "epoch_grad": avg_epoch_grad,
908
+ "epoch": epoch + 1,
909
+ }
910
+ if use_wandb:
911
+ wandb.log(log_data_ep)
912
+ if use_comet_ml:
913
+ comet_experiment.log_metrics(log_data_ep)
914
+
915
+ if accelerator.is_main_process:
916
+ print("Обучение завершено! Сохраняем финальную модель...")
917
+ save_checkpoint(transformer,"bf16")
918
+ if use_comet_ml:
919
+ comet_experiment.end()
920
+ accelerator.free_memory()
921
+ if torch.distributed.is_initialized():
922
+ torch.distributed.destroy_process_group()
923
+
924
+ print("Готово!")
transformer/config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "CosmosTransformer3DModel",
3
+ "_diffusers_version": "0.37.1",
4
+ "_name_or_path": "transformer",
5
+ "adaln_lora_dim": 256,
6
+ "attention_head_dim": 128,
7
+ "concat_padding_mask": true,
8
+ "controlnet_block_every_n": null,
9
+ "crossattn_proj_in_channels": 1024,
10
+ "encoder_hidden_states_channels": 1024,
11
+ "extra_pos_embed_type": null,
12
+ "img_context_dim_in": null,
13
+ "img_context_dim_out": 2048,
14
+ "img_context_num_tokens": 256,
15
+ "in_channels": 16,
16
+ "max_size": [
17
+ 128,
18
+ 240,
19
+ 240
20
+ ],
21
+ "mlp_ratio": 4.0,
22
+ "num_attention_heads": 16,
23
+ "num_layers": 28,
24
+ "out_channels": 16,
25
+ "patch_size": [
26
+ 1,
27
+ 2,
28
+ 2
29
+ ],
30
+ "rope_scale": [
31
+ 1.0,
32
+ 4.0,
33
+ 4.0
34
+ ],
35
+ "text_embed_dim": 1024,
36
+ "use_crossattn_projection": false
37
+ }
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:501d3b67f235189364d1bbeb3862fcdfc74e957f033e0714e8c2a12ba95a7041
3
+ size 7825687184
vae/.ipynb_checkpoints/config-checkpoint.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKLQwenImage",
3
+ "_diffusers_version": "0.36.0.dev0",
4
+ "attn_scales": [],
5
+ "base_dim": 96,
6
+ "dim_mult": [
7
+ 1,
8
+ 2,
9
+ 4,
10
+ 4
11
+ ],
12
+ "dropout": 0.0,
13
+ "latents_mean": [
14
+ -0.7571,
15
+ -0.7089,
16
+ -0.9113,
17
+ 0.1075,
18
+ -0.1745,
19
+ 0.9653,
20
+ -0.1517,
21
+ 1.5508,
22
+ 0.4134,
23
+ -0.0715,
24
+ 0.5517,
25
+ -0.3632,
26
+ -0.1922,
27
+ -0.9497,
28
+ 0.2503,
29
+ -0.2921
30
+ ],
31
+ "latents_std": [
32
+ 2.8184,
33
+ 1.4541,
34
+ 2.3275,
35
+ 2.6558,
36
+ 1.2196,
37
+ 1.7708,
38
+ 2.6052,
39
+ 2.0743,
40
+ 3.2687,
41
+ 2.1526,
42
+ 2.8652,
43
+ 1.5579,
44
+ 1.6382,
45
+ 1.1253,
46
+ 2.8251,
47
+ 1.916
48
+ ],
49
+ "num_res_blocks": 2,
50
+ "temperal_downsample": [
51
+ false,
52
+ true,
53
+ true
54
+ ],
55
+ "z_dim": 16
56
+ }
vae/config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKLQwenImage",
3
+ "_diffusers_version": "0.36.0.dev0",
4
+ "attn_scales": [],
5
+ "base_dim": 96,
6
+ "dim_mult": [
7
+ 1,
8
+ 2,
9
+ 4,
10
+ 4
11
+ ],
12
+ "dropout": 0.0,
13
+ "latents_mean": [
14
+ -0.7571,
15
+ -0.7089,
16
+ -0.9113,
17
+ 0.1075,
18
+ -0.1745,
19
+ 0.9653,
20
+ -0.1517,
21
+ 1.5508,
22
+ 0.4134,
23
+ -0.0715,
24
+ 0.5517,
25
+ -0.3632,
26
+ -0.1922,
27
+ -0.9497,
28
+ 0.2503,
29
+ -0.2921
30
+ ],
31
+ "latents_std": [
32
+ 2.8184,
33
+ 1.4541,
34
+ 2.3275,
35
+ 2.6558,
36
+ 1.2196,
37
+ 1.7708,
38
+ 2.6052,
39
+ 2.0743,
40
+ 3.2687,
41
+ 2.1526,
42
+ 2.8652,
43
+ 1.5579,
44
+ 1.6382,
45
+ 1.1253,
46
+ 2.8251,
47
+ 1.916
48
+ ],
49
+ "num_res_blocks": 2,
50
+ "temperal_downsample": [
51
+ false,
52
+ true,
53
+ true
54
+ ],
55
+ "z_dim": 16
56
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c8bc8b758c649abef9ea407b95408389a3b2f610d0d10fcb054fe171d0a8344
3
+ size 253806966
wandb/debug-cli.root.log ADDED
File without changes
wandb/debug-internal.log ADDED
The diff for this file is too large to render. See raw diff
 
wandb/debug.log ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_setup.py:_flush():81] Current SDK version is 0.26.1
2
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_setup.py:_flush():81] Configure stats pid to 14112
3
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_setup.py:_flush():81] Loading settings from environment variables
4
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:setup_run_log_directory():723] Logging user logs to /root/sdxs-2b/wandb/run-20260428_171645-wt40fdyx/logs/debug.log
5
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:setup_run_log_directory():724] Logging internal logs to /root/sdxs-2b/wandb/run-20260428_171645-wt40fdyx/logs/debug-internal.log
6
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:init():850] calling init triggers
7
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:init():855] wandb.init called with sweep_config: {}
8
+ config: {'batch_size': 16, 'base_learning_rate': 1.3333333333333335e-05, 'num_epochs': 1, 'optimizer_type': 'adafactor', '_wandb': {}}
9
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:init():898] starting backend
10
+ 2026-04-28 17:16:45,343 INFO MainThread:14112 [wandb_init.py:init():913] sending inform_init request
11
+ 2026-04-28 17:16:45,731 INFO MainThread:14112 [wandb_init.py:init():918] backend started and connected
12
+ 2026-04-28 17:16:45,734 INFO MainThread:14112 [wandb_init.py:init():988] updated telemetry
13
+ 2026-04-28 17:16:45,742 INFO MainThread:14112 [wandb_init.py:init():1011] communicating run to backend with 90.0 second timeout
14
+ 2026-04-28 17:16:46,973 INFO MainThread:14112 [wandb_init.py:init():1056] starting run threads in backend
15
+ 2026-04-28 17:16:47,099 INFO MainThread:14112 [wandb_run.py:_console_start():2554] atexit reg
16
+ 2026-04-28 17:16:47,099 INFO MainThread:14112 [wandb_run.py:_redirect():2403] redirect: wrap_raw
17
+ 2026-04-28 17:16:47,100 INFO MainThread:14112 [wandb_run.py:_redirect():2472] Wrapping output streams.
18
+ 2026-04-28 17:16:47,100 INFO MainThread:14112 [wandb_run.py:_redirect():2495] Redirects installed.
19
+ 2026-04-28 17:16:47,104 INFO MainThread:14112 [wandb_init.py:init():1094] run started, returning control to user process
wandb/offline-run-20260428_132658-o9052r27/files/requirements.txt ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cuda-toolkit==13.0.2
2
+ typing_extensions==4.15.0
3
+ nvidia-nvjitlink==13.0.88
4
+ MarkupSafe==3.0.3
5
+ nvidia-cufile==1.15.1.6
6
+ certifi==2026.4.22
7
+ nvidia-cusolver==12.0.4.66
8
+ nvidia-curand==10.4.0.35
9
+ Jinja2==3.1.6
10
+ nvidia-nvtx==13.0.85
11
+ nvidia-cuda-cupti==13.0.85
12
+ torchaudio==2.11.0+cu130
13
+ safetensors==0.7.0
14
+ nvidia-cuda-runtime==13.0.96
15
+ torchvision==0.26.0+cu130
16
+ nvidia-cufft==12.0.0.61
17
+ nvidia-cusparse==12.6.3.3
18
+ nvidia-cuda-nvrtc==13.0.88
19
+ fsspec==2026.2.0
20
+ nvidia-cusparselt-cu13==0.8.0
21
+ nvidia-nccl-cu13==2.28.9
22
+ nvidia-nvshmem-cu13==3.4.5
23
+ nvidia-cublas==13.1.0.3
24
+ nvidia-cudnn-cu13==9.19.0.56
25
+ mpmath==1.3.0
26
+ triton==3.6.0
27
+ networkx==3.6.1
28
+ sympy==1.14.0
29
+ torch==2.11.0+cu130
30
+ hf_transfer==0.1.9
31
+ six==1.17.0
32
+ typer==0.25.0
33
+ typing-inspection==0.4.2
34
+ muon-adamw8bit==0.5.0
35
+ aiosignal==1.4.0
36
+ wurlitzer==3.1.1
37
+ semantic-version==2.10.0
38
+ aiohappyeyeballs==2.6.1
39
+ cycler==0.12.1
40
+ tokenizers==0.22.2
41
+ annotated-doc==0.0.4
42
+ rpds-py==0.30.0
43
+ configobj==5.0.9
44
+ regex==2026.4.4
45
+ zipp==3.23.1
46
+ annotated-types==0.7.0
47
+ everett==3.1.0
48
+ pydantic_core==2.46.3
49
+ mdurl==0.1.2
50
+ platformdirs==4.9.6
51
+ idna==3.13
52
+ psutil==7.2.2
53
+ xxhash==3.7.0
54
+ smmap==5.0.3
55
+ frozenlist==1.8.0
56
+ multidict==6.7.1
57
+ shellingham==1.5.4
58
+ kiwisolver==1.5.0
59
+ propcache==0.4.1
60
+ h11==0.16.0
61
+ hf-xet==1.4.3
62
+ pyparsing==3.3.2
63
+ yarl==1.23.0
64
+ importlib_metadata==9.0.0
65
+ referencing==0.37.0
66
+ requests==2.33.1
67
+ filelock==3.29.0
68
+ charset-normalizer==3.4.7
69
+ wrapt==2.1.2
70
+ contourpy==1.3.3
71
+ python-box==6.1.0
72
+ python-dateutil==2.9.0.post0
73
+ packaging==26.2
74
+ httpx==0.28.1
75
+ PyYAML==6.0.3
76
+ click==8.3.3
77
+ jsonschema-specifications==2025.9.1
78
+ gitdb==4.0.12
79
+ einops==0.8.2
80
+ attrs==26.1.0
81
+ httpcore==1.0.9
82
+ cuda-pathfinder==1.5.4
83
+ requests-toolbelt==1.0.0
84
+ GitPython==3.1.48
85
+ jsonschema==4.26.0
86
+ tqdm==4.67.3
87
+ urllib3==2.6.3
88
+ anyio==4.13.0
89
+ simplejson==4.1.1
90
+ multiprocess==0.70.19
91
+ dill==0.4.1
92
+ protobuf==7.34.1
93
+ markdown-it-py==4.0.0
94
+ bitsandbytes==0.49.2
95
+ cuda-bindings==13.2.0
96
+ aiohttp==3.13.5
97
+ accelerate==1.13.0
98
+ dulwich==0.25.2
99
+ pydantic==2.13.3
100
+ datasets==4.8.5
101
+ rich==15.0.0
102
+ flash-linear-attention==0.5.0
103
+ pillow==12.2.0
104
+ huggingface_hub==1.12.0
105
+ sentry-sdk==2.58.0
106
+ fla-core==0.5.0
107
+ Pygments==2.20.0
108
+ diffusers==0.37.1
109
+ fonttools==4.62.1
110
+ comet_ml==3.57.3
111
+ setuptools==81.0.0
112
+ matplotlib==3.10.9
113
+ pyarrow==24.0.0
114
+ wandb==0.26.1
115
+ numpy==2.4.4
116
+ pandas==3.0.2
117
+ transformers==5.6.2
wandb/offline-run-20260428_132658-o9052r27/logs/debug-core.log ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2026-04-28T13:26:58.701599632Z","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpll17c794/port-6681.txt","pid":6681,"detached":false,"idle-timeout":600000000000,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
2
+ {"time":"2026-04-28T13:26:58.704326543Z","level":"INFO","msg":"server: will exit if parent process dies","ppid":6681}
3
+ {"time":"2026-04-28T13:26:58.70424692Z","level":"INFO","msg":"server: accepting connections","addr":{"Name":"/tmp/wandb-6681-6712-3956627621/socket","Net":"unix"}}
4
+ {"time":"2026-04-28T13:26:58.806869406Z","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"1(@)"}
5
+ {"time":"2026-04-28T13:26:58.828765063Z","level":"INFO","msg":"handleInformInit: received","streamId":"o9052r27","id":"1(@)"}
6
+ {"time":"2026-04-28T13:26:58.960660655Z","level":"INFO","msg":"handleInformInit: stream started","streamId":"o9052r27","id":"1(@)"}
7
+ {"time":"2026-04-28T13:27:04.392467558Z","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"1(@)"}
8
+ {"time":"2026-04-28T13:27:04.392527721Z","level":"INFO","msg":"server is shutting down"}
9
+ {"time":"2026-04-28T13:27:04.392535141Z","level":"INFO","msg":"connection: closing","id":"1(@)"}
10
+ {"time":"2026-04-28T13:27:04.392635535Z","level":"INFO","msg":"connection: closed successfully","id":"1(@)"}
11
+ {"time":"2026-04-28T13:27:04.392627225Z","level":"INFO","msg":"server: listener closed","addr":{"Name":"/tmp/wandb-6681-6712-3956627621/socket","Net":"unix"}}
12
+ {"time":"2026-04-28T13:27:04.421552415Z","level":"INFO","msg":"handleInformTeardown: server shutdown complete","id":"1(@)"}
13
+ {"time":"2026-04-28T13:27:04.421573556Z","level":"INFO","msg":"connection: ManageConnectionData: connection closed","id":"1(@)"}
14
+ {"time":"2026-04-28T13:27:04.421579966Z","level":"INFO","msg":"server is closed"}
wandb/offline-run-20260428_132658-o9052r27/logs/debug-internal.log ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2026-04-28T13:26:58.829048314Z","level":"INFO","msg":"wandb-core"}
2
+ {"time":"2026-04-28T13:26:58.829092766Z","level":"INFO","msg":"stream: starting","core version":"0.26.1"}
3
+ {"time":"2026-04-28T13:26:58.960323061Z","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
4
+ {"time":"2026-04-28T13:26:58.960354402Z","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
5
+ {"time":"2026-04-28T13:26:58.960391424Z","level":"INFO","msg":"stream: created new stream","id":"o9052r27"}
6
+ {"time":"2026-04-28T13:26:58.960480497Z","level":"INFO","msg":"handler: started"}
7
+ {"time":"2026-04-28T13:26:58.960646764Z","level":"INFO","msg":"stream: started"}
8
+ {"time":"2026-04-28T13:26:58.960704477Z","level":"INFO","msg":"writer: started","stream_id":"o9052r27"}
9
+ {"time":"2026-04-28T13:26:58.960767929Z","level":"INFO","msg":"sender: started"}
10
+ {"time":"2026-04-28T13:26:58.975123911Z","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
11
+ {"time":"2026-04-28T13:26:58.975175533Z","level":"WARN","msg":"runupserter: server does not expand metric globs but the x_server_side_expand_glob_metrics setting is set; ignoring"}
12
+ {"time":"2026-04-28T13:27:04.392744599Z","level":"INFO","msg":"stream: finishing up"}
13
+ {"time":"2026-04-28T13:27:04.39276658Z","level":"INFO","msg":"handler: closed"}
14
+ {"time":"2026-04-28T13:27:04.392811252Z","level":"INFO","msg":"sender: closed"}
15
+ {"time":"2026-04-28T13:27:04.392819012Z","level":"INFO","msg":"stream: all finished"}
wandb/offline-run-20260428_132658-o9052r27/logs/debug.log ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-04-28 13:26:58,591 INFO MainThread:6681 [wandb_setup.py:_flush():81] Current SDK version is 0.26.1
2
+ 2026-04-28 13:26:58,591 INFO MainThread:6681 [wandb_setup.py:_flush():81] Configure stats pid to 6681
3
+ 2026-04-28 13:26:58,591 INFO MainThread:6681 [wandb_setup.py:_flush():81] Loading settings from environment variables
4
+ 2026-04-28 13:26:58,591 INFO MainThread:6681 [wandb_init.py:setup_run_log_directory():723] Logging user logs to /root/sdxs-2b/wandb/offline-run-20260428_132658-o9052r27/logs/debug.log
5
+ 2026-04-28 13:26:58,592 INFO MainThread:6681 [wandb_init.py:setup_run_log_directory():724] Logging internal logs to /root/sdxs-2b/wandb/offline-run-20260428_132658-o9052r27/logs/debug-internal.log
6
+ 2026-04-28 13:26:58,592 INFO MainThread:6681 [wandb_init.py:init():850] calling init triggers
7
+ 2026-04-28 13:26:58,592 INFO MainThread:6681 [wandb_init.py:init():855] wandb.init called with sweep_config: {}
8
+ config: {'batch_size': 7, 'base_learning_rate': 1.3333333333333335e-05, 'num_epochs': 1, 'optimizer_type': 'adafactor', '_wandb': {}}
9
+ 2026-04-28 13:26:58,592 INFO MainThread:6681 [wandb_init.py:init():898] starting backend
10
+ 2026-04-28 13:26:58,807 INFO MainThread:6681 [wandb_init.py:init():913] sending inform_init request
11
+ 2026-04-28 13:26:58,961 INFO MainThread:6681 [wandb_init.py:init():918] backend started and connected
12
+ 2026-04-28 13:26:58,964 INFO MainThread:6681 [wandb_init.py:init():988] updated telemetry
13
+ 2026-04-28 13:26:58,971 INFO MainThread:6681 [wandb_init.py:init():1011] communicating run to backend with 90.0 second timeout
14
+ 2026-04-28 13:26:58,977 INFO MainThread:6681 [wandb_init.py:init():1056] starting run threads in backend
15
+ 2026-04-28 13:26:59,098 INFO MainThread:6681 [wandb_run.py:_console_start():2554] atexit reg
16
+ 2026-04-28 13:26:59,098 INFO MainThread:6681 [wandb_run.py:_redirect():2403] redirect: wrap_raw
17
+ 2026-04-28 13:26:59,099 INFO MainThread:6681 [wandb_run.py:_redirect():2472] Wrapping output streams.
18
+ 2026-04-28 13:26:59,099 INFO MainThread:6681 [wandb_run.py:_redirect():2495] Redirects installed.
19
+ 2026-04-28 13:26:59,115 INFO MainThread:6681 [wandb_init.py:init():1094] run started, returning control to user process
20
+ 2026-04-28 13:27:04,393 INFO wandb-AsyncioManager-main:6681 [service_client.py:_forward_responses():134] Reached EOF.
21
+ 2026-04-28 13:27:04,393 INFO wandb-AsyncioManager-main:6681 [mailbox.py:close():155] Closing mailbox, abandoning 0 handles.
wandb/offline-run-20260428_132658-o9052r27/run-o9052r27.wandb ADDED
Binary file (6.41 kB). View file
 
wandb/run-20260428_171645-wt40fdyx/files/output.log ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The config attributes {'final_sigmas_type': 'sigma_min', 'sigma_data': 1.0, 'sigma_max': 80.0, 'sigma_min': 0.002} were passed to FlowMatchEulerDiscreteScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.
2
+ [transformers] The fast path is not available because one of the required library is not installed. Falling back to torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and https://github.com/Dao-AILab/causal-conv1d
3
+ Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████| 473/473 [00:00<00:00, 766.86it/s]
4
+ images: 233407
5
+ Total samples 14580
6
+ Загружаем Transformer из чекпоинта: transformer
7
+ --- РАЗМОРОЖЕННЫЕ СЛОИ ---
8
+ --------------------------
9
+
10
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 44, 80])
11
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 44, 80])
12
+
13
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 48, 80])
14
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 48, 80])
15
+
16
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 40])
17
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 40])
18
+
19
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 44])
20
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 44])
21
+
22
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 48])
23
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 48])
24
+
25
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 52])
26
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 52])
27
+
28
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 56])
29
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 56])
30
+
31
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 52, 80])
32
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 52, 80])
33
+
34
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 56, 80])
35
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 56, 80])
36
+
37
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 60, 80])
38
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 60, 80])
39
+
40
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 60])
41
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 60])
42
+
43
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 64, 80])
44
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 64, 80])
45
+
46
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 68, 80])
47
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 68, 80])
48
+
49
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 72, 80])
50
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 72, 80])
51
+
52
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 76, 80])
53
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 76, 80])
54
+
55
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 64])
56
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 64])
57
+
58
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 68])
59
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 68])
60
+
61
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 72])
62
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 72])
63
+
64
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 76])
65
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 76])
66
+
67
+ [ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 40, 80])
68
+ [ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 40, 80])
69
+ Создано 20 групп фиксированных семплов по разрешениям
70
+ Генерация сэмплов до старта обучения...
71
+ /usr/lib/python3.12/contextlib.py:105: FutureWarning: `torch.backends.cuda.sdp_kernel()` is deprecated. In the future, this context manager will be removed. Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, with updated signature.
72
+ self.gen = func(*args, **kwds)
73
+
74
+ ==================================================
75
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
76
+ 1. current_latents: shape=torch.Size([1, 16, 1, 44, 80])
77
+ min=-1.9811, max=2.2364, std=0.6226
78
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
79
+ sigma_data=1.0
80
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 44, 80])
81
+ min=-3.9773, max=3.4072, std=1.1547
82
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 352, 640])
83
+ min=-1.0000, max=0.9945, std=0.6423
84
+ ==================================================
85
+
86
+
87
+ ==================================================
88
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
89
+ 1. current_latents: shape=torch.Size([1, 16, 1, 48, 80])
90
+ min=-2.1993, max=1.9178, std=0.5500
91
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
92
+ sigma_data=1.0
93
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 48, 80])
94
+ min=-3.5397, max=3.3824, std=1.0561
95
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 384, 640])
96
+ min=-1.0000, max=1.0000, std=0.3971
97
+ ==================================================
98
+
99
+
100
+ ==================================================
101
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
102
+ 1. current_latents: shape=torch.Size([1, 16, 1, 80, 40])
103
+ min=-2.7174, max=2.0244, std=0.6368
104
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
105
+ sigma_data=1.0
106
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 40])
107
+ min=-4.0544, max=4.0678, std=1.1537
108
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 320])
109
+ min=-0.9997, max=1.0000, std=0.5404
110
+ ==================================================
111
+
112
+
113
+ ==================================================
114
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
115
+ 1. current_latents: shape=torch.Size([1, 16, 1, 80, 44])
116
+ min=-2.0394, max=2.0944, std=0.5736
117
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
118
+ sigma_data=1.0
119
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 44])
120
+ min=-3.8287, max=3.3714, std=1.0290
121
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 352])
122
+ min=-1.0000, max=1.0000, std=0.4719
123
+ ==================================================
124
+
125
+
126
+ ==================================================
127
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
128
+ 1. current_latents: shape=torch.Size([1, 16, 1, 80, 48])
129
+ min=-2.0441, max=1.9221, std=0.5108
130
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
131
+ sigma_data=1.0
132
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 48])
133
+ min=-3.4324, max=3.7347, std=0.9750
134
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 384])
135
+ min=-1.0000, max=1.0000, std=0.5049
136
+ ==================================================
137
+
138
+
139
+ ==================================================
140
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
141
+ 1. current_latents: shape=torch.Size([1, 16, 1, 80, 52])
142
+ min=-2.0292, max=2.2682, std=0.7043
143
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
144
+ sigma_data=1.0
145
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 52])
146
+ min=-4.1673, max=4.4971, std=1.3949
147
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 416])
148
+ min=-1.0000, max=1.0000, std=0.6222
149
+ ==================================================
150
+
151
+
152
+ ==================================================
153
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
154
+ 1. current_latents: shape=torch.Size([1, 16, 1, 80, 56])
155
+ min=-1.7528, max=1.6711, std=0.6432
156
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
157
+ sigma_data=1.0
158
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 56])
159
+ min=-4.0104, max=4.1834, std=1.4406
160
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 448])
161
+ min=-0.9654, max=1.0000, std=0.4818
162
+ ==================================================
163
+
164
+
165
+ ==================================================
166
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
167
+ 1. current_latents: shape=torch.Size([1, 16, 1, 52, 80])
168
+ min=-2.0965, max=2.2269, std=0.4286
169
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
170
+ sigma_data=1.0
171
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 52, 80])
172
+ min=-3.3608, max=2.9338, std=0.9200
173
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 416, 640])
174
+ min=-1.0000, max=0.9774, std=0.3019
175
+ ==================================================
176
+
177
+
178
+ ==================================================
179
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
180
+ 1. current_latents: shape=torch.Size([1, 16, 1, 56, 80])
181
+ min=-2.3215, max=2.6622, std=0.6174
182
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
183
+ sigma_data=1.0
184
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 56, 80])
185
+ min=-3.6939, max=4.7696, std=1.3130
186
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 448, 640])
187
+ min=-1.0000, max=1.0000, std=0.4811
188
+ ==================================================
189
+
190
+
191
+ ==================================================
192
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
193
+ 1. current_latents: shape=torch.Size([1, 16, 1, 60, 80])
194
+ min=-2.2899, max=2.1393, std=0.5506
195
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
196
+ sigma_data=1.0
197
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 60, 80])
198
+ min=-4.0351, max=4.0100, std=1.1577
199
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 480, 640])
200
+ min=-1.0000, max=1.0000, std=0.6317
201
+ ==================================================
202
+
203
+
204
+ ==================================================
205
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
206
+ 1. current_latents: shape=torch.Size([1, 16, 1, 80, 60])
207
+ min=-1.8058, max=2.0032, std=0.5188
208
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
209
+ sigma_data=1.0
210
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 60])
211
+ min=-3.2342, max=3.6659, std=1.0352
212
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 480])
213
+ min=-1.0000, max=1.0000, std=0.6372
214
+ ==================================================
215
+
216
+
217
+ ==================================================
218
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
219
+ 1. current_latents: shape=torch.Size([1, 16, 1, 64, 80])
220
+ min=-2.1774, max=2.1568, std=0.6666
221
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
222
+ sigma_data=1.0
223
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 64, 80])
224
+ min=-4.7810, max=5.1935, std=1.3580
225
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 512, 640])
226
+ min=-1.0000, max=1.0000, std=0.5784
227
+ ==================================================
228
+
229
+
230
+ ==================================================
231
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
232
+ 1. current_latents: shape=torch.Size([1, 16, 1, 68, 80])
233
+ min=-1.9091, max=2.1057, std=0.5661
234
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
235
+ sigma_data=1.0
236
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 68, 80])
237
+ min=-3.4599, max=3.7540, std=1.0538
238
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 544, 640])
239
+ min=-1.0000, max=1.0000, std=0.6665
240
+ ==================================================
241
+
242
+
243
+ ==================================================
244
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
245
+ 1. current_latents: shape=torch.Size([1, 16, 1, 72, 80])
246
+ min=-2.1917, max=2.3725, std=0.6957
247
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
248
+ sigma_data=1.0
249
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 72, 80])
250
+ min=-3.8205, max=4.1090, std=1.5053
251
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 576, 640])
252
+ min=-1.0000, max=1.0000, std=0.6376
253
+ ==================================================
254
+
255
+
256
+ ==================================================
257
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
258
+ 1. current_latents: shape=torch.Size([1, 16, 1, 76, 80])
259
+ min=-2.3168, max=2.0439, std=0.6811
260
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
261
+ sigma_data=1.0
262
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 76, 80])
263
+ min=-3.8838, max=4.5797, std=1.3369
264
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 608, 640])
265
+ min=-1.0000, max=1.0000, std=0.6667
266
+ ==================================================
267
+
268
+
269
+ ==================================================
270
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
271
+ 1. current_latents: shape=torch.Size([1, 16, 1, 80, 64])
272
+ min=-2.2767, max=2.3007, std=0.5141
273
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
274
+ sigma_data=1.0
275
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 64])
276
+ min=-3.7021, max=3.3769, std=0.8752
277
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 512])
278
+ min=-1.0000, max=1.0000, std=0.4680
279
+ ==================================================
280
+
281
+
282
+ ==================================================
283
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
284
+ 1. current_latents: shape=torch.Size([1, 16, 1, 80, 68])
285
+ min=-2.3068, max=2.3424, std=0.7115
286
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
287
+ sigma_data=1.0
288
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 68])
289
+ min=-3.9636, max=4.6402, std=1.4684
290
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 544])
291
+ min=-0.9553, max=1.0000, std=0.4083
292
+ ==================================================
293
+
294
+
295
+ ==================================================
296
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
297
+ 1. current_latents: shape=torch.Size([1, 16, 1, 80, 72])
298
+ min=-2.3526, max=2.5922, std=0.7641
299
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
300
+ sigma_data=1.0
301
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 72])
302
+ min=-4.1452, max=4.7889, std=1.6258
303
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 576])
304
+ min=-1.0000, max=1.0000, std=0.7539
305
+ ==================================================
306
+
307
+
308
+ ==================================================
309
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
310
+ 1. current_latents: shape=torch.Size([1, 16, 1, 80, 76])
311
+ min=-1.7528, max=1.9838, std=0.4715
312
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
313
+ sigma_data=1.0
314
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 76])
315
+ min=-3.3567, max=3.4733, std=1.0891
316
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 608])
317
+ min=-0.9685, max=0.9913, std=0.4626
318
+ ==================================================
319
+
320
+
321
+ ==================================================
322
+ [ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
323
+ 1. current_latents: shape=torch.Size([1, 16, 1, 40, 80])
324
+ min=-2.1861, max=2.0078, std=0.5870
325
+ 2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
326
+ sigma_data=1.0
327
+ 3. latents_for_decode: shape=torch.Size([1, 16, 1, 40, 80])
328
+ min=-4.1700, max=4.9017, std=1.3442
329
+ 4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 320, 640])
330
+ min=-0.8055, max=0.9961, std=0.3967
331
+ ==================================================
332
+
333
+ Total steps per GPU: 14580
334
+ Training: 36%|████████████████████████████▉ | 5205/14580 [15:41:32<27:32:51, 10.58s/step]
335
+ Saving: True | Max: 0.1232 | Last: 0.1094 | Avg: 0.1232
336
+ Saving: True | Max: 0.1423 | Last: 0.1423 | Avg: 0.1245
337
+ Saving: True | Max: 0.1248 | Last: 0.0964 | Avg: 0.1248
338
+ Saving: True | Max: 0.1423 | Last: 0.1423 | Avg: 0.1246
339
+ Saving: True | Max: 0.1417 | Last: 0.1417 | Avg: 0.1236
340
+ Saving: True | Max: 0.1405 | Last: 0.1405 | Avg: 0.1234
341
+ Saving: True | Max: 0.1534 | Last: 0.1534 | Avg: 0.1232
342
+ Saving: True | Max: 0.1559 | Last: 0.1559 | Avg: 0.1234
343
+ Saving: True | Max: 0.1231 | Last: 0.1104 | Avg: 0.1231
344
+ Saving: True | Max: 0.1229 | Last: 0.1139 | Avg: 0.1229
345
+ Saving: True | Max: 0.1228 | Last: 0.1152 | Avg: 0.1228
346
+ Saving: True | Max: 0.1485 | Last: 0.1485 | Avg: 0.1228
347
+ Saving: True | Max: 0.1231 | Last: 0.0736 | Avg: 0.1231
348
+ Saving: True | Max: 0.1232 | Last: 0.1016 | Avg: 0.1232
349
+ Saving: True | Max: 0.1519 | Last: 0.1519 | Avg: 0.1234
350
+ Saving: True | Max: 0.1233 | Last: 0.1096 | Avg: 0.1233
351
+ Saving: True | Max: 0.1232 | Last: 0.1051 | Avg: 0.1232
352
+ Saving: True | Max: 0.1234 | Last: 0.1173 | Avg: 0.1234
353
+ Saving: True | Max: 0.1233 | Last: 0.1168 | Avg: 0.1233
354
+ Saving: True | Max: 0.1309 | Last: 0.1309 | Avg: 0.1229
355
+ Saving: True | Max: 0.1432 | Last: 0.1432 | Avg: 0.1227
356
+ Saving: True | Max: 0.1226 | Last: 0.1211 | Avg: 0.1226
357
+ Saving: True | Max: 0.1227 | Last: 0.1227 | Avg: 0.1221
358
+ Saving: True | Max: 0.1219 | Last: 0.1029 | Avg: 0.1219
359
+ Saving: True | Max: 0.1217 | Last: 0.1058 | Avg: 0.1217
360
+ Saving: True | Max: 0.1218 | Last: 0.1206 | Avg: 0.1218
361
+ Saving: True | Max: 0.1379 | Last: 0.1379 | Avg: 0.1221
362
+ Saving: True | Max: 0.1228 | Last: 0.1012 | Avg: 0.1228
363
+ Saving: True | Max: 0.1226 | Last: 0.1121 | Avg: 0.1226
364
+ Saving: True | Max: 0.1226 | Last: 0.0930 | Avg: 0.1226
365
+ Saving: False | Max: 0.1564 | Last: 0.1564 | Avg: 0.1230
366
+ Saving: True | Max: 0.1266 | Last: 0.1266 | Avg: 0.1234
367
+ Saving: True | Max: 0.1234 | Last: 0.1050 | Avg: 0.1234
368
+ Saving: True | Max: 0.1235 | Last: 0.1031 | Avg: 0.1235
369
+ Saving: True | Max: 0.1235 | Last: 0.0956 | Avg: 0.1235
370
+ Saving: True | Max: 0.1233 | Last: 0.1117 | Avg: 0.1233
371
+ Saving: False | Max: 0.1559 | Last: 0.1559 | Avg: 0.1229
372
+ Saving: True | Max: 0.1532 | Last: 0.1532 | Avg: 0.1234
373
+ Saving: True | Max: 0.1248 | Last: 0.1248 | Avg: 0.1231
374
+ Saving: True | Max: 0.1445 | Last: 0.1445 | Avg: 0.1228
375
+ Saving: True | Max: 0.1514 | Last: 0.1514 | Avg: 0.1229
376
+ Saving: True | Max: 0.1225 | Last: 0.1021 | Avg: 0.1225
377
+ Saving: True | Max: 0.1317 | Last: 0.1317 | Avg: 0.1221
378
+ Saving: True | Max: 0.1220 | Last: 0.1002 | Avg: 0.1220
379
+ Saving: True | Max: 0.1321 | Last: 0.1321 | Avg: 0.1221
380
+ Saving: True | Max: 0.1260 | Last: 0.1260 | Avg: 0.1218
381
+ Saving: True | Max: 0.1212 | Last: 0.1191 | Avg: 0.1212
382
+ Saving: True | Max: 0.1213 | Last: 0.1155 | Avg: 0.1213
383
+ Saving: True | Max: 0.1212 | Last: 0.1138 | Avg: 0.1212
384
+ Saving: True | Max: 0.1213 | Last: 0.1184 | Avg: 0.1213
385
+ Saving: True | Max: 0.1371 | Last: 0.1371 | Avg: 0.1215
wandb/run-20260428_171645-wt40fdyx/files/requirements.txt ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cuda-toolkit==13.0.2
2
+ typing_extensions==4.15.0
3
+ nvidia-nvjitlink==13.0.88
4
+ MarkupSafe==3.0.3
5
+ nvidia-cufile==1.15.1.6
6
+ certifi==2026.4.22
7
+ nvidia-cusolver==12.0.4.66
8
+ nvidia-curand==10.4.0.35
9
+ Jinja2==3.1.6
10
+ nvidia-nvtx==13.0.85
11
+ nvidia-cuda-cupti==13.0.85
12
+ torchaudio==2.11.0+cu130
13
+ safetensors==0.7.0
14
+ nvidia-cuda-runtime==13.0.96
15
+ torchvision==0.26.0+cu130
16
+ nvidia-cufft==12.0.0.61
17
+ nvidia-cusparse==12.6.3.3
18
+ nvidia-cuda-nvrtc==13.0.88
19
+ fsspec==2026.2.0
20
+ nvidia-cusparselt-cu13==0.8.0
21
+ nvidia-nccl-cu13==2.28.9
22
+ nvidia-nvshmem-cu13==3.4.5
23
+ nvidia-cublas==13.1.0.3
24
+ nvidia-cudnn-cu13==9.19.0.56
25
+ mpmath==1.3.0
26
+ triton==3.6.0
27
+ networkx==3.6.1
28
+ sympy==1.14.0
29
+ torch==2.11.0+cu130
30
+ hf_transfer==0.1.9
31
+ six==1.17.0
32
+ typer==0.25.0
33
+ typing-inspection==0.4.2
34
+ muon-adamw8bit==0.5.0
35
+ aiosignal==1.4.0
36
+ wurlitzer==3.1.1
37
+ semantic-version==2.10.0
38
+ aiohappyeyeballs==2.6.1
39
+ cycler==0.12.1
40
+ tokenizers==0.22.2
41
+ annotated-doc==0.0.4
42
+ rpds-py==0.30.0
43
+ configobj==5.0.9
44
+ regex==2026.4.4
45
+ zipp==3.23.1
46
+ annotated-types==0.7.0
47
+ everett==3.1.0
48
+ pydantic_core==2.46.3
49
+ mdurl==0.1.2
50
+ platformdirs==4.9.6
51
+ idna==3.13
52
+ psutil==7.2.2
53
+ xxhash==3.7.0
54
+ smmap==5.0.3
55
+ frozenlist==1.8.0
56
+ multidict==6.7.1
57
+ shellingham==1.5.4
58
+ kiwisolver==1.5.0
59
+ propcache==0.4.1
60
+ h11==0.16.0
61
+ hf-xet==1.4.3
62
+ pyparsing==3.3.2
63
+ yarl==1.23.0
64
+ importlib_metadata==9.0.0
65
+ referencing==0.37.0
66
+ requests==2.33.1
67
+ filelock==3.29.0
68
+ charset-normalizer==3.4.7
69
+ wrapt==2.1.2
70
+ contourpy==1.3.3
71
+ python-box==6.1.0
72
+ python-dateutil==2.9.0.post0
73
+ packaging==26.2
74
+ httpx==0.28.1
75
+ PyYAML==6.0.3
76
+ click==8.3.3
77
+ jsonschema-specifications==2025.9.1
78
+ gitdb==4.0.12
79
+ einops==0.8.2
80
+ attrs==26.1.0
81
+ httpcore==1.0.9
82
+ cuda-pathfinder==1.5.4
83
+ requests-toolbelt==1.0.0
84
+ GitPython==3.1.48
85
+ jsonschema==4.26.0
86
+ tqdm==4.67.3
87
+ urllib3==2.6.3
88
+ anyio==4.13.0
89
+ simplejson==4.1.1
90
+ multiprocess==0.70.19
91
+ dill==0.4.1
92
+ protobuf==7.34.1
93
+ markdown-it-py==4.0.0
94
+ bitsandbytes==0.49.2
95
+ cuda-bindings==13.2.0
96
+ aiohttp==3.13.5
97
+ accelerate==1.13.0
98
+ dulwich==0.25.2
99
+ pydantic==2.13.3
100
+ datasets==4.8.5
101
+ rich==15.0.0
102
+ flash-linear-attention==0.5.0
103
+ pillow==12.2.0
104
+ huggingface_hub==1.12.0
105
+ sentry-sdk==2.58.0
106
+ fla-core==0.5.0
107
+ Pygments==2.20.0
108
+ diffusers==0.37.1
109
+ fonttools==4.62.1
110
+ comet_ml==3.57.3
111
+ setuptools==81.0.0
112
+ matplotlib==3.10.9
113
+ pyarrow==24.0.0
114
+ wandb==0.26.1
115
+ numpy==2.4.4
116
+ pandas==3.0.2
117
+ transformers==5.6.2
wandb/run-20260428_171645-wt40fdyx/files/wandb-metadata.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-6.8.0-110-generic-x86_64-with-glibc2.39",
3
+ "python": "CPython 3.12.3",
4
+ "startedAt": "2026-04-28T17:16:45.135482Z",
5
+ "args": [
6
+ "--batch",
7
+ "16",
8
+ "--lvl",
9
+ "1"
10
+ ],
11
+ "program": "/root/sdxs-2b/train.py",
12
+ "codePath": "train.py",
13
+ "codePathLocal": "train.py",
14
+ "git": {
15
+ "remote": "https://huggingface.co/AiArtLab/sdxs-2b",
16
+ "commit": "ab8719f79299a6e86448b407298689048767b261"
17
+ },
18
+ "email": "vadim-kulibaba@yandex.ru",
19
+ "root": "/root/sdxs-2b",
20
+ "host": "O-1649582",
21
+ "executable": "/root/.venv/bin/python3",
22
+ "cpu_count": 48,
23
+ "cpu_count_logical": 96,
24
+ "gpu": "NVIDIA GeForce RTX 5090",
25
+ "gpu_count": 1,
26
+ "disk": {
27
+ "/": {
28
+ "total": "888178696192",
29
+ "used": "598432870400"
30
+ }
31
+ },
32
+ "memory": {
33
+ "total": "134889213952"
34
+ },
35
+ "gpu_nvidia": [
36
+ {
37
+ "name": "NVIDIA GeForce RTX 5090",
38
+ "memoryTotal": "34190917632",
39
+ "cudaCores": 21760,
40
+ "architecture": "Blackwell",
41
+ "uuid": "GPU-af06c899-cefd-2303-137f-17f69c648771"
42
+ }
43
+ ],
44
+ "cudaVersion": "13.0",
45
+ "writerId": "9ndk10qtzdsvighcagxlxbtug93n98at"
46
+ }
wandb/run-20260428_171645-wt40fdyx/logs/debug-core.log ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {"time":"2026-04-28T17:16:45.179418861Z","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpvlugdolt/port-14112.txt","pid":14112,"detached":false,"idle-timeout":600000000000,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
2
+ {"time":"2026-04-28T17:16:45.18135139Z","level":"INFO","msg":"server: will exit if parent process dies","ppid":14112}
3
+ {"time":"2026-04-28T17:16:45.181241106Z","level":"INFO","msg":"server: accepting connections","addr":{"Name":"/tmp/wandb-14112-14129-3488405791/socket","Net":"unix"}}
4
+ {"time":"2026-04-28T17:16:45.343101118Z","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"1(@)"}
5
+ {"time":"2026-04-28T17:16:45.350678398Z","level":"INFO","msg":"handleInformInit: received","streamId":"wt40fdyx","id":"1(@)"}
6
+ {"time":"2026-04-28T17:16:45.730308466Z","level":"INFO","msg":"handleInformInit: stream started","streamId":"wt40fdyx","id":"1(@)"}
7
+ {"time":"2026-04-28T17:16:54.001250093Z","level":"INFO","msg":"connection: cancelling request","id":"1(@)","requestId":"lrv3btfsqddl"}
wandb/run-20260428_171645-wt40fdyx/logs/debug-internal.log ADDED
The diff for this file is too large to render. See raw diff
 
wandb/run-20260428_171645-wt40fdyx/logs/debug.log ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_setup.py:_flush():81] Current SDK version is 0.26.1
2
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_setup.py:_flush():81] Configure stats pid to 14112
3
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_setup.py:_flush():81] Loading settings from environment variables
4
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:setup_run_log_directory():723] Logging user logs to /root/sdxs-2b/wandb/run-20260428_171645-wt40fdyx/logs/debug.log
5
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:setup_run_log_directory():724] Logging internal logs to /root/sdxs-2b/wandb/run-20260428_171645-wt40fdyx/logs/debug-internal.log
6
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:init():850] calling init triggers
7
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:init():855] wandb.init called with sweep_config: {}
8
+ config: {'batch_size': 16, 'base_learning_rate': 1.3333333333333335e-05, 'num_epochs': 1, 'optimizer_type': 'adafactor', '_wandb': {}}
9
+ 2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:init():898] starting backend
10
+ 2026-04-28 17:16:45,343 INFO MainThread:14112 [wandb_init.py:init():913] sending inform_init request
11
+ 2026-04-28 17:16:45,731 INFO MainThread:14112 [wandb_init.py:init():918] backend started and connected
12
+ 2026-04-28 17:16:45,734 INFO MainThread:14112 [wandb_init.py:init():988] updated telemetry
13
+ 2026-04-28 17:16:45,742 INFO MainThread:14112 [wandb_init.py:init():1011] communicating run to backend with 90.0 second timeout
14
+ 2026-04-28 17:16:46,973 INFO MainThread:14112 [wandb_init.py:init():1056] starting run threads in backend
15
+ 2026-04-28 17:16:47,099 INFO MainThread:14112 [wandb_run.py:_console_start():2554] atexit reg
16
+ 2026-04-28 17:16:47,099 INFO MainThread:14112 [wandb_run.py:_redirect():2403] redirect: wrap_raw
17
+ 2026-04-28 17:16:47,100 INFO MainThread:14112 [wandb_run.py:_redirect():2472] Wrapping output streams.
18
+ 2026-04-28 17:16:47,100 INFO MainThread:14112 [wandb_run.py:_redirect():2495] Redirects installed.
19
+ 2026-04-28 17:16:47,104 INFO MainThread:14112 [wandb_init.py:init():1094] run started, returning control to user process
wandb/run-20260428_171645-wt40fdyx/run-wt40fdyx.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84f6382ddf7402e5b98379478d7897d5a678f939eba1a8a5d028a988674120a5
3
+ size 15499264